[
  {
    "path": ".clang-format",
    "content": "# Run the following command to reformat a file using the settings in this file:\n# clang-format -i -style=file <file>\n# Or use clang-format-diff to reformat only the changed lines:\n# https://clang.llvm.org/docs/ClangFormat.html\nBasedOnStyle: Google\nDerivePointerAlignment: false\nColumnLimit:     100\nPointerAlignment: Left\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.md",
    "content": "---\nname: \"🐛 Bug Report\"\nabout: Submit a bug report to help us improve MLC-LLM\ntitle: '[Bug] '\nlabels: ['bug']\nassignees: ''\n\n---\n\n## 🐛 Bug\n\n<!-- A clear and concise description of what the bug is. -->\n\n## To Reproduce\n\nSteps to reproduce the behavior:\n\n1.\n1.\n1.\n\n<!-- If you have a code sample, error messages, or stack traces, please provide them here as well. -->\n\n## Expected behavior\n\n<!-- A clear and concise description of what you expected to happen. -->\n\n## Environment\n\n - Platform (e.g. WebGPU/Vulkan/iOS/Android/CUDA):\n - Operating system (e.g. Ubuntu/Windows/macOS/...):\n - Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...):\n - How you installed MLC-LLM (`conda`, source):\n - How you installed TVM (`pip`, source):\n - Python version (e.g. 3.10):\n - GPU driver version (if applicable):\n - CUDA/cuDNN version (if applicable):\n - TVM Hash Tag (`python -c \"import tvm; print('\\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))\"`, applicable if you compile models):\n - Any other relevant information:\n\n## Additional context\n\n<!-- Add any other context about the problem here. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\n\ncontact_links:\n  - name: Check the MLC-LLM Documentation\n    url: https://llm.mlc.ai/docs/\n    about: Our documentation might provide answers to your questions.\n  - name: Chat on Discord\n    url: https://discord.gg/9Xpy2HGBuD\n    about: Join the Discord Server to live chat with the community.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documentation.md",
    "content": "---\nname: \"\\U0001F4DA Documentation\"\nabout: Report an issue related to https://llm.mlc.ai/docs/\ntitle: '[Doc] '\nlabels: ['documentation']\nassignees: ''\n\n---\n\n## 📚 Documentation\n\n### Suggestion\n<!-- Please leave your general suggestions for our documentation here. -->\n\n### Bug\n- Link to the buggy documentation/tutorial:\n- Description of the bug:\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.md",
    "content": "---\nname: \"\\U0001F680 Feature Request\"\nabout: Submit a proposal/request for a new MLC-LLM feature or an enhancement to an existing feature.\ntitle: '[Feature Request] '\nlabels: ['feature request']\nassignees: ''\n\n---\n\n## 🚀 Feature\n<!-- A brief description of the feature proposal -->\n\n## Motivation\n\n<!-- Please outline the motivation for the proposal and how this feature could benefit the MLC-LLM project/community. -->\n\n## Alternatives\n\n<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->\n\n## Additional context\n\n<!-- Add any other context or screenshots about the feature request here. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/general.md",
    "content": "---\nname: \"❓ General Questions\"\nabout: General questions you have about MLC-LLM.\ntitle: '[Question] '\nlabels: ['question']\nassignees: ''\n\n---\n\n## ❓ General Questions\n\n<!-- Describe your questions -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/model-request.md",
    "content": "---\nname: \"⚙️ Model Request\"\nabout: Request a new model in MLC-LLM\ntitle: '[Model Request] '\nlabels: ['new-models']\nassignees: ''\n\n---\n\n## ⚙️ Request New Models\n\n- Link to an existing implementation (e.g. Hugging Face/GitHub): <!-- Link to the model -->\n- Is this model architecture supported by MLC-LLM? (see the list of [supported models](https://llm.mlc.ai/docs/prebuilt_models.html)) <!-- Yes/No -->\n\n## Additional context\n\n<!-- Add any other context that you think would help the community add this model -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/speed-report.md",
    "content": "---\nname: \"🏎️ Speed Report\"\nabout: Submit a speed report for a model running in MLC-LLM\ntitle: '[Speed] '\nlabels: ['performance']\nassignees: ''\n\n---\n\n# 🏎️ Speed Report\n\n<!-- Please search for existing issues discussing the speed of the model you are using. If there are any, we encourage you to reply in the existing issue instead of creating a new one. -->\n\n- The model code: <!-- e.g. vicuna-7b-1.1 -->\n\n\n- The model configuration (e.g. quantization mode, running data type, etc.):\n- Device (e.g. MacBook Pro M2, PC+RTX 3080):\n- OS (if applicable):\n- Encode speed (tokens/s):\n- Decode speed (tokens/s):\n- Memory usage (if applicable):\n\n<!-- Note that the measured speed might reflect peak performance if the prompt/chat history is short. -->\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/tracking.md",
    "content": "---\nname: \"Tracking\"\nabout: A tracking issue that tracks an ongoing item in the project\ntitle: '[Tracking] '\nlabels: ['status: tracking']\nassignees: ''\n\n---\n\n<!--\n\nA tracking issue contains a list of action items\nthat can be executed to complete a feature or fix.\n\nWe use tracking issues when we have a clear list of action items\nrelated to a feature, as they provide a fine-grained\nview of action items and clarity on what it takes to implement a feature.\n\nWhen to open a tracking issue: Open a new tracking issue when you have\nclear, actionable items (as a rule of thumb, make sure the action\nitems could be carried through if you were assigned to work on them and\nthat you can provide enough guidance to others who plan to work on these actions).\n-->\n\n\n## Overview\n<!-- A brief overview of the task -->\n\n\n\n## Action Items\n<!-- Please list the set of action items to complete -->\n\n- [ ]\n\n\n## Links to Related Issues and PRs\n\n<!-- Cross-link feature requests and bug reports related to the tracking item -->\n<!-- When new PRs are opened, link them here as well -->\n"
  },
  {
    "path": ".github/workflows/documentation.yaml",
    "content": "name: Build Docs\n\non:\n  push:\n    branches:\n      - main\n\njobs:\n  test_linux:\n    name: Deploy Docs\n    runs-on: ubuntu-latest\n\n    steps:\n    - uses: actions/checkout@v4\n      with:\n        submodules: recursive\n\n    - name: Configure build environment\n      run: |\n        sudo apt-get update\n        python -m pip install -U pip wheel\n\n    - name: Setup Ruby\n      uses: ruby/setup-ruby@v1\n      with:\n        ruby-version: '3.0'\n\n    - name: Install dependencies\n      run: |\n        python -m pip install -r docs/requirements.txt\n        gem install jekyll jekyll-remote-theme\n\n    - name: Deploy on GitHub Pages\n      if: github.ref == 'refs/heads/main'\n      run: |\n        git remote set-url origin https://x-access-token:${{ secrets.MLC_GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY\n        git config --global user.email \"mlc-gh-actions-bot@nomail\"\n        git config --global user.name \"mlc-gh-actions-bot\"\n        ./scripts/gh_deploy_site.sh\n"
  },
  {
    "path": ".github/workflows/update-relax.yaml",
    "content": "name: 'Relax Submodule Sync'\n\non:\n  workflow_dispatch:\n\njobs:\n  sync:\n    name: 'Relax Submodule Sync'\n    runs-on: ubuntu-latest\n\n    defaults:\n      run:\n        shell: bash\n\n    steps:\n    - name: Checkout\n      uses: actions/checkout@v4\n      with:\n        submodules: true\n\n    - name: Git Submodule Update\n      run: |\n        git submodule update --remote 3rdparty/tvm\n\n    - name: Commit update\n      env:\n        GITHUB_TOKEN: ${{ secrets.MLC_GITHUB_TOKEN }}\n      run: |\n        git config --global user.name 'Git bot'\n        git config --global user.email 'bot@noreply.github.com'\n        git remote set-url origin https://$GITHUB_TOKEN@github.com/mlc-ai/mlc-llm\n        git commit -am \"Auto updated submodule references\" && git push || echo \"No changes to commit\"\n"
  },
  {
    "path": ".github/workflows/windows-build.yaml",
    "content": "# GitHub Actions workflow.\n# We use it to cover Windows builds.\n# Jenkins is still the primary CI.\nname: Windows CI\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n    branches:\n      - main\n\njobs:\n  Windows:\n    runs-on: windows-latest\n    defaults:\n      run:\n        shell: 'cmd /C call {0}'\n\n    steps:\n    - name: Git config\n      run: >-\n        git config --system core.longpaths true\n    - uses: actions/checkout@v3\n      with:\n        submodules: 'recursive'\n    - uses: conda-incubator/setup-miniconda@v3\n      with:\n        activate-environment: mlc-llm-build\n        channel-priority: strict\n        environment-file: ci/build-environment.yaml\n        auto-activate-base: false\n    - name: Conda info\n      run: |\n        conda info\n        conda list\n        python --version\n    - name: Build MLC-LLM\n      run: >-\n        ci/task/build_win.bat\n"
  },
  {
    "path": ".gitignore",
    "content": "tmp/\ndist/\nparams/\ndebug/\n*.bak\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n.DS_Store\n\n*.S\n# C extensions\n*.so\n\nbuild/\n\n*.ll\n.npm\n# Distribution / packaging\n.Python\nenv/\nbuild/\nbuild-*/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n.conda/\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Generated by python/gen_requirements.py\npython/requirements/*.txt\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\ndocs/_staging/\n\n# PyBuilder\ntarget/\n/target/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n*~\n*.pyc\n*~\nconfig.mk\nconfig.cmake\nWin32\n*.dir\nperf\n*.wasm\n.emscripten\n\n## IOS\nDerivedData/\n\n## Java\n*.class\njvm/*/target/\njvm/*/*/target/\njvm/native/*/generated\njvm/native/src/main/native/org_apache_tvm_native_c_api.h\n*.worksheet\n*.idea\n*.iml\n*.classpath\n*.project\n*.settings\n*/node_modules/\n\n## Various settings\n*.pbxuser\n!default.pbxuser\n*.mode1v3\n!default.mode1v3\n*.mode2v3\n!default.mode2v3\n*.perspectivev3\n!default.perspectivev3\nxcuserdata/\n.pkl_memoize_*\n\n.emscripten*\n.m2\n\n# Compiled Dynamic libraries\n*.so\n*.dylib\n*.dll\n\n# Compiled Object files\n*.slo\n*.lo\n*.o\n*.obj\n\n# Precompiled Headers\n*.gch\n*.pch\n\n# Compiled Static libraries\n*.lai\n*.la\n*.a\n*.lib\n\n# Executables\n*.exe\n*.out\n*.app\n\n## Other\n*.moved-aside\n*.xccheckout\n*.xcscmblueprint\n.DS_Store\ntags\ncscope*\n*.lock\n\n# vim temporary files\n*.swp\n*.swo\n\n# TVM generated code\nperf\n.bash_history\n# *.json\n*.params\n*.ro\n*.onnx\n*.h5\nsynset.txt\ncat.jpg\ncat.png\ndocs.tgz\ncat.png\n*.mlmodel\ntvm_u.*\ntvm_t.*\n# Mac OS X\n.DS_Store\n\n# Jetbrain\n.idea\n.ipython\n.jupyter\n.nv\n.pylint.d\n.python_history\n.pytest_cache\n.local\ncmake-build-debug\n\n# Visual Studio\n.vs\n\n# Visual Studio Code\n.vscode\n\n# tmp file\n.nfs*\n\n# keys\n*.pem\n*.p12\n*.pfx\n*.cer\n*.crt\n*.der\n\n# patch sentinel\npatched.txt\n\n# Python type checking\n.mypy_cache/\n.pyre/\n\n# pipenv files\nPipfile\nPipfile.lock\n\n# conda package artifacts\nconda/Dockerfile.cuda*\nconda/pkg\n.node_repl_history\n# nix files\n.envrc\n*.nix\n\n# Docker files\n.sudo_as_admin_successful\n\n# Downloaded models/datasets\n.tvm_test_data\n.dgl\n.caffe2\n\n# Local docs 
build\n_docs/\njvm/target\n.config/configstore/\n.ci-py-scripts/\n\n# Generated Hexagon files\nsrc/runtime/hexagon/rpc/hexagon_rpc.h\nsrc/runtime/hexagon/rpc/hexagon_rpc_skel.c\nsrc/runtime/hexagon/rpc/hexagon_rpc_stub.c\n\n# Local tvm-site checkout\ntvm-site/\n\n# Generated docs files\ngallery/how_to/work_with_microtvm/micro_tvmc.py\n\n# Test sample data files\n!tests/python/ci/sample_prs/*.json\n\n# Used in CI to communicate between Python and Jenkins\n.docker-image-names/\n\n# Printed TIR code on disk\n*.tir\n\n# GDB history file\n.gdb_history\n\ndist\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"3rdparty/argparse\"]\n\tpath = 3rdparty/argparse\n\turl = https://github.com/p-ranav/argparse\n[submodule \"3rdparty/tokenizers-cpp\"]\n\tpath = 3rdparty/tokenizers-cpp\n\turl = https://github.com/mlc-ai/tokenizers-cpp\n[submodule \"3rdparty/googletest\"]\n\tpath = 3rdparty/googletest\n\turl = https://github.com/google/googletest.git\n[submodule \"3rdparty/tvm\"]\n\tpath = 3rdparty/tvm\n\turl = https://github.com/mlc-ai/relax.git\n[submodule \"3rdparty/stb\"]\n\tpath = 3rdparty/stb\n\turl = https://github.com/nothings/stb.git\n[submodule \"3rdparty/xgrammar\"]\n\tpath = 3rdparty/xgrammar\n\turl = https://github.com/mlc-ai/xgrammar.git\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "# To use:\n#\n#     pre-commit run -a\n#\n# Or:\n#\n#     pre-commit install  # (runs every time you commit in git)\n#\n# To update this file:\n#\n#     pre-commit autoupdate\n#\n# See https://github.com/pre-commit/pre-commit\n# Note: the pre-commit hooks should only be used for formatting, not for linting.\n# For linting, consider using CI.\nrepos:\n  # Standard hooks\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v5.0.0\n    hooks:\n      - id: check-added-large-files\n      - id: check-case-conflict\n      - id: check-merge-conflict\n      - id: check-symlinks\n      - id: end-of-file-fixer\n      - id: mixed-line-ending\n      - id: requirements-txt-fixer\n      - id: trailing-whitespace\n\n  # Replaces tabs with spaces and removes CRLF line endings\n  - repo: https://github.com/Lucas-C/pre-commit-hooks\n    rev: v1.5.5\n    hooks:\n      - id: remove-tabs\n      - id: remove-crlf\n\n  # Formatters\n  - repo: https://github.com/psf/black-pre-commit-mirror\n    rev: 24.8.0\n    hooks:\n      - id: black\n\n  - repo: https://github.com/pycqa/isort\n    rev: 5.13.2\n    hooks:\n      - id: isort\n\n  - repo: https://github.com/pre-commit/mirrors-clang-format\n    rev: v19.1.1\n    hooks:\n      - id: clang-format\n        types_or: [c++, c, cuda]\n        exclude: |\n          (?x)^(.*cubin.cpp$ | .*fmha_cubin.h | 3rdparty/.*)$\n\n  - repo: https://github.com/cheshirekow/cmake-format-precommit\n    rev: v0.6.13\n    hooks:\n      - id: cmake-format\n        additional_dependencies: [pyyaml>=5.1]\n"
  },
  {
    "path": ".pylintrc",
    "content": "[MESSAGES CONTROL]\ndisable=too-many-positional-arguments,duplicate-code\n"
  },
  {
    "path": "CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.18)\nproject(mlc_llm C CXX)\n\ninclude(CheckCXXCompilerFlag)\nif(MSVC)\n  set(CMAKE_CXX_FLAGS \"/fp:fast ${CMAKE_CXX_FLAGS}\")\nelse()\n  set(CMAKE_CXX_FLAGS \"-ffast-math ${CMAKE_CXX_FLAGS}\")\nendif()\n\nif(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)\n  include(${CMAKE_BINARY_DIR}/config.cmake)\nelse()\n  if(EXISTS ${CMAKE_SOURCE_DIR}/config.cmake)\n    include(${CMAKE_SOURCE_DIR}/config.cmake)\n  endif()\nendif()\n\nif(NOT CMAKE_BUILD_TYPE)\n  set(CMAKE_BUILD_TYPE\n      RelWithDebInfo\n      CACHE STRING \"Build type\" FORCE)\n  message(STATUS \"Setting default build type to \" ${CMAKE_BUILD_TYPE})\nendif(NOT CMAKE_BUILD_TYPE)\n\noption(MLC_HIDE_PRIVATE_SYMBOLS \"Hide private symbols\" ON)\noption(MLC_LLM_BUILD_PYTHON_MODULE \"Build Python module with scikit-build-core\"\n       OFF)\n\nif(MLC_LLM_INSTALL_STATIC_LIB)\n  set(BUILD_STATIC_RUNTIME ON)\nendif()\n\nset(MLC_VISIBILITY_FLAG \"\")\nif(MLC_HIDE_PRIVATE_SYMBOLS)\n  set(HIDE_PRIVATE_SYMBOLS ON)\n  if(NOT MSVC)\n    set(MLC_VISIBILITY_FLAG \"-fvisibility=hidden\")\n  endif()\n  message(STATUS \"Hide private symbols\")\nendif()\n\noption(BUILD_CPP_TEST \"Build cpp unittests\" OFF)\n\nset(CMAKE_CUDA_STANDARD 17)\nset(CMAKE_CXX_STANDARD 17)\nset(CMAKE_POSITION_INDEPENDENT_CODE ON)\n\n# tvm runtime config: minimize runtime components\nset(USE_RPC OFF)\nset(USE_MICRO OFF)\nset(USE_GRAPH_EXECUTOR OFF)\nset(USE_GRAPH_EXECUTOR_DEBUG OFF)\nset(USE_AOT_EXECUTOR OFF)\nset(USE_PROFILER OFF)\nset(USE_GTEST OFF)\nset(USE_LIBBACKTRACE OFF)\nset(BUILD_DUMMY_LIBTVM ON)\nif(NOT DEFINED TVM_SOURCE_DIR)\n  if(DEFINED ENV{TVM_SOURCE_DIR})\n    set(TVM_SOURCE_DIR \"$ENV{TVM_SOURCE_DIR}\")\n  else()\n    set(TVM_SOURCE_DIR 3rdparty/tvm)\n  endif(DEFINED ENV{TVM_SOURCE_DIR})\nendif(NOT DEFINED TVM_SOURCE_DIR)\nmessage(STATUS \"TVM_SOURCE_DIR: ${TVM_SOURCE_DIR}\")\nadd_subdirectory(${TVM_SOURCE_DIR} tvm EXCLUDE_FROM_ALL)\n\nset(MLC_LLM_RUNTIME_LINKER_LIB \"\")\nset(TOKENZIER_CPP_PATH 
3rdparty/tokenizers-cpp)\nadd_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL)\n\nset(XGRAMMAR_PATH 3rdparty/xgrammar)\ntvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc)\ntvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc)\nlist(FILTER XGRAMMAR_SRCS EXCLUDE REGEX \"${XGRAMMAR_PATH}/cpp/pybind/.*\\\\.cc\")\nlist(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS})\nadd_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS})\n\nset(MLC_LLM_INCLUDES\n    ${TVM_SOURCE_DIR}/include ${TVM_SOURCE_DIR}/3rdparty/dlpack/include)\nset(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} __STDC_FORMAT_MACROS=1)\nset(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS} XGRAMMAR_ENABLE_LOG_DEBUG=0)\n\ntarget_compile_definitions(mlc_llm_objs PRIVATE ${MLC_LLM_COMPILE_DEFS})\ntarget_compile_definitions(mlc_llm_objs PRIVATE -DMLC_LLM_EXPORTS)\ntarget_include_directories(mlc_llm_objs PRIVATE ${MLC_LLM_INCLUDES})\ntarget_include_directories(mlc_llm_objs PRIVATE 3rdparty/stb)\ntarget_include_directories(mlc_llm_objs PRIVATE ${TOKENZIER_CPP_PATH}/include)\ntarget_include_directories(mlc_llm_objs PRIVATE ${XGRAMMAR_PATH}/include)\n# xgrammar still depends on picojson - use its bundled copy\ntarget_include_directories(mlc_llm_objs\n                           PRIVATE ${XGRAMMAR_PATH}/3rdparty/picojson)\ntarget_link_libraries(mlc_llm_objs PRIVATE tvm_ffi_header)\n\nadd_library(mlc_llm SHARED $<TARGET_OBJECTS:mlc_llm_objs>)\nadd_library(mlc_llm_static STATIC $<TARGET_OBJECTS:mlc_llm_objs>)\nadd_dependencies(mlc_llm_static tokenizers_cpp sentencepiece-static\n                 tokenizers_c tvm_runtime)\nset_target_properties(mlc_llm_static PROPERTIES OUTPUT_NAME mlc_llm)\n\ntarget_link_libraries(mlc_llm PUBLIC tvm_runtime)\ntarget_link_libraries(mlc_llm PRIVATE tokenizers_cpp)\n\nfind_library(FLASH_ATTN_LIBRARY flash_attn\n             HINTS ${TVM_SOURCE_DIR}/*/3rdparty/libflash_attn/src)\n\nif(FLASH_ATTN_LIBRARY STREQUAL \"FLASH_ATTN_LIBRARY-NOTFOUND\")\n  message(\n    WARNING\n      \"Cannot find 
libflash_attn. The model must not have been built with --use-flash-attn-mqa option.\"\n  )\nelse()\n  target_link_libraries(mlc_llm PUBLIC -Wl,--no-as-needed ${FLASH_ATTN_LIBRARY})\nendif()\n\nif(CMAKE_BUILD_TYPE STREQUAL \"Debug\")\n  target_compile_definitions(mlc_llm PRIVATE \"TVM_LOG_DEBUG\")\n  target_compile_definitions(mlc_llm_objs PRIVATE \"TVM_LOG_DEBUG\")\n  target_compile_definitions(mlc_llm_static PRIVATE \"TVM_LOG_DEBUG\")\nendif()\n\nif(BUILD_CPP_TEST)\n  message(STATUS \"Building cpp unittests\")\n  add_subdirectory(3rdparty/googletest)\n  file(GLOB_RECURSE MLC_LLM_TEST_SRCS\n       ${PROJECT_SOURCE_DIR}/tests/cpp/*unittest.cc)\n  add_executable(mlc_llm_cpp_tests ${MLC_LLM_TEST_SRCS})\n  target_include_directories(mlc_llm_cpp_tests PRIVATE ${MLC_LLM_INCLUDES})\n  target_include_directories(mlc_llm_cpp_tests\n                             PRIVATE ${PROJECT_SOURCE_DIR}/cpp)\n  target_include_directories(\n    mlc_llm_cpp_tests PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})\n  target_link_libraries(mlc_llm_cpp_tests PUBLIC mlc_llm gtest gtest_main)\nendif(BUILD_CPP_TEST)\n\nif(CMAKE_SYSTEM_NAME STREQUAL \"Android\")\n  target_link_libraries(mlc_llm PRIVATE log)\n  target_link_libraries(tokenizers_cpp PRIVATE log)\nendif()\n\nadd_library(mlc_llm_module SHARED $<TARGET_OBJECTS:mlc_llm_objs>)\ntarget_link_libraries(mlc_llm_module PUBLIC tvm)\ntarget_link_libraries(mlc_llm_module PRIVATE tokenizers_cpp)\n\nset_property(\n  TARGET mlc_llm_module\n  APPEND\n  PROPERTY LINK_OPTIONS \"${MLC_VISIBILITY_FLAG}\")\nset_property(\n  TARGET mlc_llm\n  APPEND\n  PROPERTY LINK_OPTIONS \"${MLC_VISIBILITY_FLAG}\")\n\nfind_program(CARGO_EXECUTABLE cargo)\n\nif(NOT CARGO_EXECUTABLE)\n  message(FATAL_ERROR \"Cargo is not found! 
Please install cargo.\")\nendif()\n\n# When this option is on, install all static library dependencies into lib\nif(MLC_LLM_INSTALL_STATIC_LIB)\n  install(TARGETS mlc_llm_static tokenizers_cpp sentencepiece-static tvm_runtime\n          LIBRARY DESTINATION lib${LIB_SUFFIX})\n  # tokenizers needs special handling because it is built from Rust\n  if(MSVC)\n    install(FILES ${CMAKE_CURRENT_BINARY_DIR}/tokenizers/libtokenizers_c.lib\n            DESTINATION lib${LIB_SUFFIX})\n  else()\n    install(FILES ${CMAKE_CURRENT_BINARY_DIR}/tokenizers/libtokenizers_c.a\n            DESTINATION lib${LIB_SUFFIX})\n  endif()\nelse()\n  install(\n    TARGETS tvm_runtime\n            mlc_llm\n            mlc_llm_module\n            mlc_llm_static\n            tokenizers_cpp\n            sentencepiece-static\n            RUNTIME_DEPENDENCY_SET\n            tokenizers_c\n    RUNTIME DESTINATION bin\n    LIBRARY DESTINATION lib${LIB_SUFFIX})\nendif()\n\n# Python package installation configuration. This section ensures that all\n# necessary files are installed for the Python wheel.\nif(MLC_LLM_BUILD_PYTHON_MODULE)\n  message(STATUS \"Configuring Python package installation\")\n\n  # Set RPATH so that mlc_llm and mlc_llm_module find other libraries relative to themselves\n  if(APPLE)\n    # macOS uses @loader_path\n    set_target_properties(mlc_llm PROPERTIES INSTALL_RPATH \"@loader_path\")\n    set_target_properties(mlc_llm_module PROPERTIES INSTALL_RPATH\n                                                    \"@loader_path\")\n  elseif(LINUX)\n    # Linux uses $ORIGIN\n    set_target_properties(mlc_llm PROPERTIES INSTALL_RPATH \"\\$ORIGIN\")\n    set_target_properties(mlc_llm_module PROPERTIES INSTALL_RPATH \"\\$ORIGIN\")\n  endif()\n\n  # Install compiled shared libraries\n  install(TARGETS mlc_llm DESTINATION \".\")\n  install(TARGETS mlc_llm_module DESTINATION \".\")\n  install(DIRECTORY \"${CMAKE_CURRENT_SOURCE_DIR}/cpp/\" DESTINATION \"cpp/\")\n  install(DIRECTORY \"${CMAKE_CURRENT_SOURCE_DIR}/web/\" DESTINATION 
\"web/\")\n  install(FILES \"${CMAKE_CURRENT_SOURCE_DIR}/README.md\"\n                \"${CMAKE_CURRENT_SOURCE_DIR}/LICENSE\"\n                \"${CMAKE_CURRENT_SOURCE_DIR}/NOTICE\" DESTINATION \".\")\n\n  message(STATUS \"Python package installation configured\")\nendif()\n"
  },
  {
    "path": "CONTRIBUTORS.md",
    "content": "MLC LLM Contributors\n====================\n\n\n## List of Contributors\n- [Full List of Contributors](https://github.com/mlc-ai/mlc-llm/graphs/contributors)\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "NOTICE",
    "content": "MLC LLM\n\nCopyright (c) 2023-2025 by MLC LLM Contributors\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# MLC LLM\n\n[![Installation](https://img.shields.io/badge/docs-latest-green)](https://llm.mlc.ai/docs/)\n[![License](https://img.shields.io/badge/license-apache_2-blue)](https://github.com/mlc-ai/mlc-llm/blob/main/LICENSE)\n[![Join Discord](https://img.shields.io/badge/Join-Discord-7289DA?logo=discord&logoColor=white)](https://discord.gg/9Xpy2HGBuD)\n[![Related Repository: WebLLM](https://img.shields.io/badge/Related_Repo-WebLLM-fafbfc?logo=github)](https://github.com/mlc-ai/web-llm/)\n\n**Universal LLM Deployment Engine with ML Compilation**\n\n[Get Started](https://llm.mlc.ai/docs/get_started/quick_start) | [Documentation](https://llm.mlc.ai/docs) | [Blog](https://blog.mlc.ai/)\n\n</div>\n\n## About\n\nMLC LLM is a machine learning compiler and high-performance deployment engine for large language models.  The mission of this project is to enable everyone to develop, optimize, and deploy AI models natively on everyone's platforms.\n\n<div align=\"center\">\n<table style=\"width:100%\">\n  <thead>\n    <tr>\n      <th style=\"width:15%\"> </th>\n      <th style=\"width:20%\">AMD GPU</th>\n      <th style=\"width:20%\">NVIDIA GPU</th>\n      <th style=\"width:20%\">Apple GPU</th>\n      <th style=\"width:24%\">Intel GPU</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <td>Linux / Win</td>\n      <td>✅ Vulkan, ROCm</td>\n      <td>✅ Vulkan, CUDA</td>\n      <td>N/A</td>\n      <td>✅ Vulkan</td>\n    </tr>\n    <tr>\n      <td>macOS</td>\n      <td>✅ Metal (dGPU)</td>\n      <td>N/A</td>\n      <td>✅ Metal</td>\n      <td>✅ Metal (iGPU)</td>\n    </tr>\n    <tr>\n      <td>Web Browser</td>\n      <td colspan=4>✅ WebGPU and WASM </td>\n    </tr>\n    <tr>\n      <td>iOS / iPadOS</td>\n      <td colspan=4>✅ Metal on Apple A-series GPU</td>\n    </tr>\n    <tr>\n      <td>Android</td>\n      <td colspan=2>✅ OpenCL on Adreno GPU</td>\n      <td colspan=2>✅ OpenCL on Mali GPU</td>\n    </tr>\n  
</tbody>\n</table>\n</div>\n\nMLC LLM compiles and runs code on MLCEngine, a unified high-performance LLM inference engine across the above platforms. MLCEngine provides an OpenAI-compatible API that is available through a REST server, Python, JavaScript, iOS, and Android, all backed by the same engine and compiler that we keep improving with the community.\n\n## Get Started\n\nPlease visit our [documentation](https://llm.mlc.ai/docs/) to get started with MLC LLM.\n- [Installation](https://llm.mlc.ai/docs/install/mlc_llm)\n- [Quick start](https://llm.mlc.ai/docs/get_started/quick_start)\n- [Introduction](https://llm.mlc.ai/docs/get_started/introduction)\n\n## Citation\n\nPlease consider citing our project if you find it useful:\n\n```bibtex\n@software{mlc-llm,\n    author = {{MLC team}},\n    title = {{MLC-LLM}},\n    url = {https://github.com/mlc-ai/mlc-llm},\n    year = {2023-2025}\n}\n```\n\nThe underlying techniques of MLC LLM include:\n\n<details>\n  <summary>References (Click to expand)</summary>\n\n  ```bibtex\n  @inproceedings{tensorir,\n      author = {Feng, Siyuan and Hou, Bohan and Jin, Hongyi and Lin, Wuwei and Shao, Junru and Lai, Ruihang and Ye, Zihao and Zheng, Lianmin and Yu, Cody Hao and Yu, Yong and Chen, Tianqi},\n      title = {TensorIR: An Abstraction for Automatic Tensorized Program Optimization},\n      year = {2023},\n      isbn = {9781450399166},\n      publisher = {Association for Computing Machinery},\n      address = {New York, NY, USA},\n      url = {https://doi.org/10.1145/3575693.3576933},\n      doi = {10.1145/3575693.3576933},\n      booktitle = {Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2},\n      pages = {804–817},\n      numpages = {14},\n      keywords = {Tensor Computation, Machine Learning Compiler, Deep Neural Network},\n      location = {Vancouver, BC, Canada},\n      series = {ASPLOS 2023}\n  }\n\n  @inproceedings{metaschedule,\n      author = 
{Shao, Junru and Zhou, Xiyou and Feng, Siyuan and Hou, Bohan and Lai, Ruihang and Jin, Hongyi and Lin, Wuwei and Masuda, Masahiro and Yu, Cody Hao and Chen, Tianqi},\n      booktitle = {Advances in Neural Information Processing Systems},\n      editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},\n      pages = {35783--35796},\n      publisher = {Curran Associates, Inc.},\n      title = {Tensor Program Optimization with Probabilistic Programs},\n      url = {https://proceedings.neurips.cc/paper_files/paper/2022/file/e894eafae43e68b4c8dfdacf742bcbf3-Paper-Conference.pdf},\n      volume = {35},\n      year = {2022}\n  }\n\n  @inproceedings{tvm,\n      author = {Tianqi Chen and Thierry Moreau and Ziheng Jiang and Lianmin Zheng and Eddie Yan and Haichen Shen and Meghan Cowan and Leyuan Wang and Yuwei Hu and Luis Ceze and Carlos Guestrin and Arvind Krishnamurthy},\n      title = {{TVM}: An Automated {End-to-End} Optimizing Compiler for Deep Learning},\n      booktitle = {13th USENIX Symposium on Operating Systems Design and Implementation (OSDI 18)},\n      year = {2018},\n      isbn = {978-1-939133-08-3},\n      address = {Carlsbad, CA},\n      pages = {578--594},\n      url = {https://www.usenix.org/conference/osdi18/presentation/chen},\n      publisher = {USENIX Association},\n      month = oct,\n  }\n  ```\n</details>\n"
  },
  {
    "path": "android/.gitignore",
    "content": "app/src/main/jni/*.h\napp/src/main/jni/*.cc\napp/src/main/obj\n\n*.iml\n.gradle\n/local.properties\n/.idea/caches\n/.idea/libraries\n/.idea/modules.xml\n/.idea/workspace.xml\n/.idea/navEditor.xml\n/.idea/assetWizardSettings.xml\n.DS_Store\n/build\n/captures\n.externalNativeBuild\n.cxx\nlocal.properties\n"
  },
  {
    "path": "android/MLCChat/README.md",
    "content": "# MLC-LLM Android\n\nCheck out the [documentation page](https://llm.mlc.ai/docs/deploy/android.html) for more information.\n\n- Run `mlc_llm package`.\n- Open this `MLCChat/` folder as a project in Android Studio.\n"
  },
  {
    "path": "android/MLCChat/app/.gitignore",
    "content": "/build\n/src/main/libs\n"
  },
  {
    "path": "android/MLCChat/app/build.gradle",
    "content": "plugins {\n    id 'com.android.application'\n    id 'org.jetbrains.kotlin.android'\n}\n\nandroid {\n    namespace 'ai.mlc.mlcchat'\n    compileSdk 35\n\n    defaultConfig {\n        applicationId \"ai.mlc.mlcchat\"\n        minSdk 26\n        targetSdk 33\n        versionCode 1\n        versionName \"1.0\"\n\n        testInstrumentationRunner \"androidx.test.runner.AndroidJUnitRunner\"\n        vectorDrawables {\n            useSupportLibrary true\n        }\n    }\n\n    buildTypes {\n        release {\n            minifyEnabled false\n            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'\n        }\n    }\n    compileOptions {\n        sourceCompatibility JavaVersion.VERSION_1_8\n        targetCompatibility JavaVersion.VERSION_1_8\n    }\n    kotlinOptions {\n        jvmTarget = '1.8'\n    }\n    buildFeatures {\n        compose true\n    }\n    composeOptions {\n        kotlinCompilerExtensionVersion '1.4.3'\n    }\n    packagingOptions {\n        resources {\n            excludes += '/META-INF/{AL2.0,LGPL2.1}'\n        }\n    }\n}\n\ndependencies {\n    implementation project(\":mlc4j\")\n    implementation 'androidx.core:core-ktx:1.10.1'\n    implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1'\n    implementation 'com.github.jeziellago:compose-markdown:0.5.2'\n    implementation 'androidx.activity:activity-compose:1.7.1'\n    implementation platform('androidx.compose:compose-bom:2022.10.00')\n    implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1'\n    implementation 'androidx.compose.ui:ui'\n    implementation 'androidx.compose.ui:ui-graphics'\n    implementation 'androidx.compose.ui:ui-tooling-preview'\n    implementation 'androidx.compose.material3:material3:1.1.0'\n    implementation 'androidx.compose.material:material-icons-extended'\n    implementation 'androidx.appcompat:appcompat:1.6.1'\n    implementation 'androidx.navigation:navigation-compose:2.5.3'\n    
implementation 'com.google.code.gson:gson:2.10.1'\n    implementation fileTree(dir: 'src/main/libs', include: ['*.aar', '*.jar'], exclude: [])\n    testImplementation 'junit:junit:4.13.2'\n    androidTestImplementation 'androidx.test.ext:junit:1.1.5'\n    androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'\n    androidTestImplementation platform('androidx.compose:compose-bom:2022.10.00')\n    androidTestImplementation 'androidx.compose.ui:ui-test-junit4'\n    debugImplementation 'androidx.compose.ui:ui-tooling'\n    debugImplementation 'androidx.compose.ui:ui-test-manifest'\n\n}\n"
  },
  {
    "path": "android/MLCChat/app/proguard-rules.pro",
    "content": "# Add project specific ProGuard rules here.\n# You can control the set of applied configuration files using the\n# proguardFiles setting in build.gradle.\n#\n# For more details, see\n#   http://developer.android.com/guide/developing/tools/proguard.html\n\n# If your project uses WebView with JS, uncomment the following\n# and specify the fully qualified class name to the JavaScript interface\n# class:\n#-keepclassmembers class fqcn.of.javascript.interface.for.webview {\n#   public *;\n#}\n\n# Uncomment this to preserve the line number information for\n# debugging stack traces.\n#-keepattributes SourceFile,LineNumberTable\n\n# If you keep the line number information, uncomment this to\n# hide the original source file name.\n#-renamesourcefileattribute SourceFile\n"
  },
  {
    "path": "android/MLCChat/app/src/main/AndroidManifest.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<manifest xmlns:android=\"http://schemas.android.com/apk/res/android\"\n    xmlns:tools=\"http://schemas.android.com/tools\"\n    package=\"ai.mlc.mlcchat\">\n\n    <uses-permission android:name=\"android.permission.INTERNET\" />\n    <uses-permission android:name=\"android.permission.READ_MEDIA_IMAGES\" />\n    <uses-permission\n        android:name=\"android.permission.WRITE_EXTERNAL_STORAGE\"\n        android:maxSdkVersion=\"32\"\n        tools:ignore=\"ScopedStorage\" />\n\n    <application\n        android:allowBackup=\"true\"\n        android:dataExtractionRules=\"@xml/data_extraction_rules\"\n        android:fullBackupContent=\"@xml/backup_rules\"\n        android:icon=\"@drawable/mlc_logo_108\"\n        android:label=\"@string/app_name\"\n        android:roundIcon=\"@drawable/mlc_logo_108\"\n        android:supportsRtl=\"true\"\n        android:theme=\"@style/Theme.MLCChat\"\n        tools:targetApi=\"31\">\n        <uses-native-library\n            android:name=\"libOpenCL.so\"\n            android:required=\"false\"/>\n\n        <uses-native-library\n            android:name=\"libOpenCL-pixel.so\"\n            android:required=\"false\" />\n        <activity\n            android:name=\".MainActivity\"\n            android:exported=\"true\"\n            android:label=\"@string/app_name\"\n            android:theme=\"@android:style/Theme.Material.NoActionBar\">\n            <intent-filter>\n                <action android:name=\"android.intent.action.MAIN\" />\n                <category android:name=\"android.intent.category.LAUNCHER\" />\n            </intent-filter>\n        </activity>\n    </application>\n\n</manifest>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/AppViewModel.kt",
    "content": "package ai.mlc.mlcchat\n\nimport ai.mlc.mlcllm.MLCEngine\nimport ai.mlc.mlcllm.OpenAIProtocol\nimport android.app.Application\nimport android.content.ClipData\nimport android.content.ClipboardManager\nimport android.content.Context\nimport android.os.Environment\nimport android.widget.Toast\nimport androidx.compose.runtime.mutableStateOf\nimport androidx.compose.runtime.toMutableStateList\nimport androidx.lifecycle.AndroidViewModel\nimport androidx.lifecycle.viewModelScope\nimport com.google.gson.Gson\nimport com.google.gson.annotations.SerializedName\nimport kotlinx.coroutines.launch\nimport java.io.File\nimport java.io.FileOutputStream\nimport java.net.URL\nimport java.nio.channels.Channels\nimport java.util.UUID\nimport java.util.concurrent.Executors\nimport kotlin.concurrent.thread\nimport ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessage\nimport ai.mlc.mlcllm.OpenAIProtocol.ChatCompletionMessageContent\nimport android.app.Activity\nimport kotlinx.coroutines.*\nimport android.graphics.Bitmap\nimport android.graphics.BitmapFactory\nimport android.net.Uri\nimport java.io.ByteArrayOutputStream\nimport android.util.Base64\nimport android.util.Log\n\nclass AppViewModel(application: Application) : AndroidViewModel(application) {\n    val modelList = emptyList<ModelState>().toMutableStateList()\n    val chatState = ChatState()\n    val modelSampleList = emptyList<ModelRecord>().toMutableStateList()\n    private var showAlert = mutableStateOf(false)\n    private var alertMessage = mutableStateOf(\"\")\n    private var appConfig = AppConfig(\n        emptyList<String>().toMutableList(),\n        emptyList<ModelRecord>().toMutableList()\n    )\n    private val application = getApplication<Application>()\n    private val appDirFile = application.getExternalFilesDir(\"\")\n    private val gson = Gson()\n    private val modelIdSet = emptySet<String>().toMutableSet()\n\n    companion object {\n        const val AppConfigFilename = \"mlc-app-config.json\"\n   
     const val ModelConfigFilename = \"mlc-chat-config.json\"\n        const val ParamsConfigFilename = \"tensor-cache.json\"\n        const val ModelUrlSuffix = \"resolve/main/\"\n    }\n\n    init {\n        loadAppConfig()\n    }\n\n    fun isShowingAlert(): Boolean {\n        return showAlert.value\n    }\n\n    fun errorMessage(): String {\n        return alertMessage.value\n    }\n\n    fun dismissAlert() {\n        require(showAlert.value)\n        showAlert.value = false\n    }\n\n    fun copyError() {\n        require(showAlert.value)\n        val clipboard =\n            application.getSystemService(Context.CLIPBOARD_SERVICE) as ClipboardManager\n        clipboard.setPrimaryClip(ClipData.newPlainText(\"MLCChat\", errorMessage()))\n    }\n\n    private fun issueAlert(error: String) {\n        showAlert.value = true\n        alertMessage.value = error\n    }\n\n    fun requestDeleteModel(modelId: String) {\n        deleteModel(modelId)\n        issueAlert(\"Model: $modelId has been deleted\")\n    }\n\n\n    private fun loadAppConfig() {\n        val appConfigFile = File(appDirFile, AppConfigFilename)\n        val jsonString: String = if (!appConfigFile.exists()) {\n            application.assets.open(AppConfigFilename).bufferedReader().use { it.readText() }\n        } else {\n            appConfigFile.readText()\n        }\n        appConfig = gson.fromJson(jsonString, AppConfig::class.java)\n        appConfig.modelLibs = emptyList<String>().toMutableList()\n        modelList.clear()\n        modelIdSet.clear()\n        modelSampleList.clear()\n        for (modelRecord in appConfig.modelList) {\n            appConfig.modelLibs.add(modelRecord.modelLib)\n            val modelDirFile = File(appDirFile, modelRecord.modelId)\n            val modelConfigFile = File(modelDirFile, ModelConfigFilename)\n            if (modelConfigFile.exists()) {\n                val modelConfigString = modelConfigFile.readText()\n                val modelConfig = 
gson.fromJson(modelConfigString, ModelConfig::class.java)\n                modelConfig.modelId = modelRecord.modelId\n                modelConfig.modelLib = modelRecord.modelLib\n                modelConfig.estimatedVramBytes = modelRecord.estimatedVramBytes\n                addModelConfig(modelConfig, modelRecord.modelUrl, true)\n            } else {\n                downloadModelConfig(\n                    if (modelRecord.modelUrl.endsWith(\"/\")) modelRecord.modelUrl else \"${modelRecord.modelUrl}/\",\n                    modelRecord,\n                    true\n                )\n            }\n        }\n    }\n\n    private fun updateAppConfig(action: () -> Unit) {\n        action()\n        val jsonString = gson.toJson(appConfig)\n        val appConfigFile = File(appDirFile, AppConfigFilename)\n        appConfigFile.writeText(jsonString)\n    }\n\n    private fun addModelConfig(modelConfig: ModelConfig, modelUrl: String, isBuiltin: Boolean) {\n        require(!modelIdSet.contains(modelConfig.modelId))\n        modelIdSet.add(modelConfig.modelId)\n        modelList.add(\n            ModelState(\n                modelConfig,\n                modelUrl + if (modelUrl.endsWith(\"/\")) \"\" else \"/\",\n                File(appDirFile, modelConfig.modelId)\n            )\n        )\n        if (!isBuiltin) {\n            updateAppConfig {\n                appConfig.modelList.add(\n                    ModelRecord(\n                        modelUrl,\n                        modelConfig.modelId,\n                        modelConfig.estimatedVramBytes,\n                        modelConfig.modelLib\n                    )\n                )\n            }\n        }\n    }\n\n    private fun deleteModel(modelId: String) {\n        val modelDirFile = File(appDirFile, modelId)\n        modelDirFile.deleteRecursively()\n        require(!modelDirFile.exists())\n        modelIdSet.remove(modelId)\n        modelList.removeIf { modelState -> modelState.modelConfig.modelId == 
modelId }\n        updateAppConfig {\n            appConfig.modelList.removeIf { modelRecord -> modelRecord.modelId == modelId }\n        }\n    }\n\n    private fun isModelConfigAllowed(modelConfig: ModelConfig): Boolean {\n        if (appConfig.modelLibs.contains(modelConfig.modelLib)) return true\n        viewModelScope.launch {\n            issueAlert(\"Model lib ${modelConfig.modelLib} is not supported.\")\n        }\n        return false\n    }\n\n\n    private fun downloadModelConfig(\n        modelUrl: String,\n        modelRecord: ModelRecord,\n        isBuiltin: Boolean\n    ) {\n        thread(start = true) {\n            try {\n                val url = URL(\"${modelUrl}${ModelUrlSuffix}${ModelConfigFilename}\")\n                val tempId = UUID.randomUUID().toString()\n                val tempFile = File(\n                    application.getExternalFilesDir(Environment.DIRECTORY_DOWNLOADS),\n                    tempId\n                )\n                url.openStream().use {\n                    Channels.newChannel(it).use { src ->\n                        FileOutputStream(tempFile).use { fileOutputStream ->\n                            fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE)\n                        }\n                    }\n                }\n                require(tempFile.exists())\n                viewModelScope.launch {\n                    try {\n                        val modelConfigString = tempFile.readText()\n                        val modelConfig = gson.fromJson(modelConfigString, ModelConfig::class.java)\n                        modelConfig.modelId = modelRecord.modelId\n                        modelConfig.modelLib = modelRecord.modelLib\n                        modelConfig.estimatedVramBytes = modelRecord.estimatedVramBytes\n                        if (modelIdSet.contains(modelConfig.modelId)) {\n                            tempFile.delete()\n                            issueAlert(\"${modelConfig.modelId} has 
been used, please consider another local ID\")\n                            return@launch\n                        }\n                        if (!isModelConfigAllowed(modelConfig)) {\n                            tempFile.delete()\n                            return@launch\n                        }\n                        val modelDirFile = File(appDirFile, modelConfig.modelId)\n                        val modelConfigFile = File(modelDirFile, ModelConfigFilename)\n                        tempFile.copyTo(modelConfigFile, overwrite = true)\n                        tempFile.delete()\n                        require(modelConfigFile.exists())\n                        addModelConfig(modelConfig, modelUrl, isBuiltin)\n                    } catch (e: Exception) {\n                        viewModelScope.launch {\n                            issueAlert(\"Add model failed: ${e.localizedMessage}\")\n                        }\n                    }\n                }\n            } catch (e: Exception) {\n                viewModelScope.launch {\n                    issueAlert(\"Download model config failed: ${e.localizedMessage}\")\n                }\n            }\n\n        }\n    }\n\n    inner class ModelState(\n        val modelConfig: ModelConfig,\n        private val modelUrl: String,\n        private val modelDirFile: File\n    ) {\n        var modelInitState = mutableStateOf(ModelInitState.Initializing)\n        private var paramsConfig = ParamsConfig(emptyList())\n        val progress = mutableStateOf(0)\n        val total = mutableStateOf(1)\n        val id: UUID = UUID.randomUUID()\n        private val remainingTasks = emptySet<DownloadTask>().toMutableSet()\n        private val downloadingTasks = emptySet<DownloadTask>().toMutableSet()\n        private val maxDownloadTasks = 3\n        private val gson = Gson()\n\n\n        init {\n            switchToInitializing()\n        }\n\n        private fun switchToInitializing() {\n            val paramsConfigFile = 
File(modelDirFile, ParamsConfigFilename)\n            if (paramsConfigFile.exists()) {\n                loadParamsConfig()\n                switchToIndexing()\n            } else {\n                downloadParamsConfig()\n            }\n        }\n\n        private fun loadParamsConfig() {\n            val paramsConfigFile = File(modelDirFile, ParamsConfigFilename)\n            require(paramsConfigFile.exists())\n            val jsonString = paramsConfigFile.readText()\n            paramsConfig = gson.fromJson(jsonString, ParamsConfig::class.java)\n        }\n\n        private fun downloadParamsConfig() {\n            thread(start = true) {\n                val url = URL(\"${modelUrl}${ModelUrlSuffix}${ParamsConfigFilename}\")\n                val tempId = UUID.randomUUID().toString()\n                val tempFile = File(modelDirFile, tempId)\n                url.openStream().use {\n                    Channels.newChannel(it).use { src ->\n                        FileOutputStream(tempFile).use { fileOutputStream ->\n                            fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE)\n                        }\n                    }\n                }\n                require(tempFile.exists())\n                val paramsConfigFile = File(modelDirFile, ParamsConfigFilename)\n                tempFile.renameTo(paramsConfigFile)\n                require(paramsConfigFile.exists())\n                viewModelScope.launch {\n                    loadParamsConfig()\n                    switchToIndexing()\n                }\n            }\n        }\n\n        fun handleStart() {\n            switchToDownloading()\n        }\n\n        fun handlePause() {\n            switchToPausing()\n        }\n\n        fun handleClear() {\n            require(\n                modelInitState.value == ModelInitState.Downloading ||\n                        modelInitState.value == ModelInitState.Paused ||\n                        modelInitState.value == 
ModelInitState.Finished\n            )\n            switchToClearing()\n        }\n\n        private fun switchToClearing() {\n            if (modelInitState.value == ModelInitState.Paused) {\n                modelInitState.value = ModelInitState.Clearing\n                clear()\n            } else if (modelInitState.value == ModelInitState.Finished) {\n                modelInitState.value = ModelInitState.Clearing\n                if (chatState.modelName.value == modelConfig.modelId) {\n                    chatState.requestTerminateChat { clear() }\n                } else {\n                    clear()\n                }\n            } else {\n                modelInitState.value = ModelInitState.Clearing\n            }\n        }\n\n        fun handleDelete() {\n            require(\n                modelInitState.value == ModelInitState.Downloading ||\n                        modelInitState.value == ModelInitState.Paused ||\n                        modelInitState.value == ModelInitState.Finished\n            )\n            switchToDeleting()\n        }\n\n        private fun switchToDeleting() {\n            if (modelInitState.value == ModelInitState.Paused) {\n                modelInitState.value = ModelInitState.Deleting\n                delete()\n            } else if (modelInitState.value == ModelInitState.Finished) {\n                modelInitState.value = ModelInitState.Deleting\n                if (chatState.modelName.value == modelConfig.modelId) {\n                    chatState.requestTerminateChat { delete() }\n                } else {\n                    delete()\n                }\n            } else {\n                modelInitState.value = ModelInitState.Deleting\n            }\n        }\n\n        private fun switchToIndexing() {\n            modelInitState.value = ModelInitState.Indexing\n            progress.value = 0\n            total.value = modelConfig.tokenizerFiles.size + paramsConfig.paramsRecords.size\n            for 
(tokenizerFilename in modelConfig.tokenizerFiles) {\n                val file = File(modelDirFile, tokenizerFilename)\n                if (file.exists()) {\n                    ++progress.value\n                } else {\n                    remainingTasks.add(\n                        DownloadTask(\n                            URL(\"${modelUrl}${ModelUrlSuffix}${tokenizerFilename}\"),\n                            file\n                        )\n                    )\n                }\n            }\n            for (paramsRecord in paramsConfig.paramsRecords) {\n                val file = File(modelDirFile, paramsRecord.dataPath)\n                if (file.exists()) {\n                    ++progress.value\n                } else {\n                    remainingTasks.add(\n                        DownloadTask(\n                            URL(\"${modelUrl}${ModelUrlSuffix}${paramsRecord.dataPath}\"),\n                            file\n                        )\n                    )\n                }\n            }\n            if (progress.value < total.value) {\n                switchToPaused()\n            } else {\n                switchToFinished()\n            }\n        }\n\n        private fun switchToDownloading() {\n            modelInitState.value = ModelInitState.Downloading\n            for (downloadTask in remainingTasks) {\n                if (downloadingTasks.size < maxDownloadTasks) {\n                    handleNewDownload(downloadTask)\n                } else {\n                    return\n                }\n            }\n        }\n\n        private fun handleNewDownload(downloadTask: DownloadTask) {\n            require(modelInitState.value == ModelInitState.Downloading)\n            require(!downloadingTasks.contains(downloadTask))\n            downloadingTasks.add(downloadTask)\n            thread(start = true) {\n                val tempId = UUID.randomUUID().toString()\n                val tempFile = File(modelDirFile, tempId)\n            
    downloadTask.url.openStream().use {\n                    Channels.newChannel(it).use { src ->\n                        FileOutputStream(tempFile).use { fileOutputStream ->\n                            fileOutputStream.channel.transferFrom(src, 0, Long.MAX_VALUE)\n                        }\n                    }\n                }\n                require(tempFile.exists())\n                tempFile.renameTo(downloadTask.file)\n                require(downloadTask.file.exists())\n                viewModelScope.launch {\n                    handleFinishDownload(downloadTask)\n                }\n            }\n        }\n\n        private fun handleNextDownload() {\n            require(modelInitState.value == ModelInitState.Downloading)\n            for (downloadTask in remainingTasks) {\n                if (!downloadingTasks.contains(downloadTask)) {\n                    handleNewDownload(downloadTask)\n                    break\n                }\n            }\n        }\n\n        private fun handleFinishDownload(downloadTask: DownloadTask) {\n            remainingTasks.remove(downloadTask)\n            downloadingTasks.remove(downloadTask)\n            ++progress.value\n            require(\n                modelInitState.value == ModelInitState.Downloading ||\n                        modelInitState.value == ModelInitState.Pausing ||\n                        modelInitState.value == ModelInitState.Clearing ||\n                        modelInitState.value == ModelInitState.Deleting\n            )\n            if (modelInitState.value == ModelInitState.Downloading) {\n                if (remainingTasks.isEmpty()) {\n                    if (downloadingTasks.isEmpty()) {\n                        switchToFinished()\n                    }\n                } else {\n                    handleNextDownload()\n                }\n            } else if (modelInitState.value == ModelInitState.Pausing) {\n                if (downloadingTasks.isEmpty()) {\n                   
 switchToPaused()\n                }\n            } else if (modelInitState.value == ModelInitState.Clearing) {\n                if (downloadingTasks.isEmpty()) {\n                    clear()\n                }\n            } else if (modelInitState.value == ModelInitState.Deleting) {\n                if (downloadingTasks.isEmpty()) {\n                    delete()\n                }\n            }\n        }\n\n        private fun clear() {\n            val files = modelDirFile.listFiles { dir, name ->\n                !(dir == modelDirFile && name == ModelConfigFilename)\n            }\n            require(files != null)\n            for (file in files) {\n                file.deleteRecursively()\n                require(!file.exists())\n            }\n            val modelConfigFile = File(modelDirFile, ModelConfigFilename)\n            require(modelConfigFile.exists())\n            switchToIndexing()\n        }\n\n        private fun delete() {\n            modelDirFile.deleteRecursively()\n            require(!modelDirFile.exists())\n            requestDeleteModel(modelConfig.modelId)\n        }\n\n        private fun switchToPausing() {\n            modelInitState.value = ModelInitState.Pausing\n        }\n\n        private fun switchToPaused() {\n            modelInitState.value = ModelInitState.Paused\n        }\n\n\n        private fun switchToFinished() {\n            modelInitState.value = ModelInitState.Finished\n        }\n\n        fun startChat() {\n            chatState.requestReloadChat(\n                modelConfig,\n                modelDirFile.absolutePath,\n            )\n        }\n\n    }\n\n    inner class ChatState {\n        val messages = emptyList<MessageData>().toMutableStateList()\n        val report = mutableStateOf(\"\")\n        val modelName = mutableStateOf(\"\")\n        private var modelChatState = mutableStateOf(ModelChatState.Ready)\n            @Synchronized get\n            @Synchronized set\n        private val engine = 
MLCEngine()\n        private var historyMessages = mutableListOf<ChatCompletionMessage>()\n        private var modelLib = \"\"\n        private var modelPath = \"\"\n        private val executorService = Executors.newSingleThreadExecutor()\n        private val viewModelScope = CoroutineScope(Dispatchers.Main + Job())\n        private var imageUri: Uri? = null\n        private fun mainResetChat() {\n            imageUri = null\n            executorService.submit {\n                callBackend { engine.reset() }\n                historyMessages = mutableListOf<ChatCompletionMessage>()\n                viewModelScope.launch {\n                    clearHistory()\n                    switchToReady()\n                }\n            }\n        }\n\n        private fun clearHistory() {\n            messages.clear()\n            report.value = \"\"\n            historyMessages.clear()\n        }\n\n\n        private fun switchToResetting() {\n            modelChatState.value = ModelChatState.Resetting\n        }\n\n        private fun switchToGenerating() {\n            modelChatState.value = ModelChatState.Generating\n        }\n\n        private fun switchToReloading() {\n            modelChatState.value = ModelChatState.Reloading\n        }\n\n        private fun switchToReady() {\n            modelChatState.value = ModelChatState.Ready\n        }\n\n        private fun switchToFailed() {\n            modelChatState.value = ModelChatState.Falied\n        }\n\n        private fun callBackend(callback: () -> Unit): Boolean {\n            try {\n                callback()\n            } catch (e: Exception) {\n                viewModelScope.launch {\n                    val stackTrace = e.stackTraceToString()\n                    val errorMessage = e.localizedMessage\n                    appendMessage(\n                        MessageRole.Assistant,\n                        \"MLCChat failed\\n\\nStack trace:\\n$stackTrace\\n\\nError message:\\n$errorMessage\"\n              
      )\n                    switchToFailed()\n                }\n                return false\n            }\n            return true\n        }\n\n        fun requestResetChat() {\n            require(interruptable())\n            interruptChat(\n                prologue = {\n                    switchToResetting()\n                },\n                epilogue = {\n                    mainResetChat()\n                }\n            )\n        }\n\n        private fun interruptChat(prologue: () -> Unit, epilogue: () -> Unit) {\n            // prologue runs before interruption\n            // epilogue runs after interruption\n            require(interruptable())\n            if (modelChatState.value == ModelChatState.Ready) {\n                prologue()\n                epilogue()\n            } else if (modelChatState.value == ModelChatState.Generating) {\n                prologue()\n                executorService.submit {\n                    viewModelScope.launch { epilogue() }\n                }\n            } else {\n                require(false)\n            }\n        }\n\n        fun requestTerminateChat(callback: () -> Unit) {\n            require(interruptable())\n            interruptChat(\n                prologue = {\n                    switchToTerminating()\n                },\n                epilogue = {\n                    mainTerminateChat(callback)\n                }\n            )\n        }\n\n        private fun mainTerminateChat(callback: () -> Unit) {\n            executorService.submit {\n                callBackend { engine.unload() }\n                viewModelScope.launch {\n                    clearHistory()\n                    switchToReady()\n                    callback()\n                }\n            }\n        }\n\n        private fun switchToTerminating() {\n            modelChatState.value = ModelChatState.Terminating\n        }\n\n\n        fun requestReloadChat(modelConfig: ModelConfig, modelPath: String) {\n\n            
if (this.modelName.value == modelConfig.modelId && this.modelLib == modelConfig.modelLib && this.modelPath == modelPath) {\n                return\n            }\n            require(interruptable())\n            interruptChat(\n                prologue = {\n                    switchToReloading()\n                },\n                epilogue = {\n                    mainReloadChat(modelConfig, modelPath)\n                }\n            )\n        }\n\n        private fun mainReloadChat(modelConfig: ModelConfig, modelPath: String) {\n            clearHistory()\n            this.modelName.value = modelConfig.modelId\n            this.modelLib = modelConfig.modelLib\n            this.modelPath = modelPath\n            executorService.submit {\n                viewModelScope.launch {\n                    Toast.makeText(application, \"Initialize...\", Toast.LENGTH_SHORT).show()\n                }\n                if (!callBackend {\n                        engine.unload()\n                        engine.reload(modelPath, modelConfig.modelLib)\n                    }) return@submit\n                viewModelScope.launch {\n                    Toast.makeText(application, \"Ready to chat\", Toast.LENGTH_SHORT).show()\n                    switchToReady()\n                }\n            }\n        }\n\n        fun requestImageBitmap(uri: Uri?) 
{\n            require(chatable())\n            switchToGenerating()\n            executorService.submit {\n                imageUri = uri\n                viewModelScope.launch {\n                    report.value = \"Image process is done, ask any question.\"\n                    if (modelChatState.value == ModelChatState.Generating) switchToReady()\n                }\n            }\n        }\n\n        fun bitmapToURL(bm: Bitmap): String {\n            val targetSize = 336\n            val scaledBitmap = Bitmap.createScaledBitmap(bm, targetSize, targetSize, true)\n\n            val outputStream = ByteArrayOutputStream()\n            scaledBitmap.compress(Bitmap.CompressFormat.JPEG, 100, outputStream)\n            scaledBitmap.recycle()\n\n            val imageBytes = outputStream.toByteArray()\n            val imageBase64 = Base64.encodeToString(imageBytes, Base64.NO_WRAP)\n            return \"data:image/jpg;base64,$imageBase64\"\n        }\n\n        fun requestGenerate(prompt: String, activity: Activity) {\n            require(chatable())\n            switchToGenerating()\n            appendMessage(MessageRole.User, prompt)\n            appendMessage(MessageRole.Assistant, \"\")\n            var content = ChatCompletionMessageContent(text=prompt)\n            if (imageUri != null) {\n                val uri = imageUri\n                val bitmap = uri?.let {\n                    activity.contentResolver.openInputStream(it)?.use { input ->\n                        BitmapFactory.decodeStream(input)\n                    }\n                }\n                val imageBase64URL = bitmapToURL(bitmap!!)\n                Log.v(\"requestGenerate\", \"image base64 url: $imageBase64URL\")\n                val parts = listOf(\n                    mapOf(\"type\" to \"text\", \"text\" to prompt),\n                    mapOf(\"type\" to \"image_url\", \"image_url\" to imageBase64URL)\n                )\n                content = ChatCompletionMessageContent(parts=parts)\n    
            imageUri = null\n            }\n\n            executorService.submit {\n                historyMessages.add(ChatCompletionMessage(\n                    role = OpenAIProtocol.ChatCompletionRole.user,\n                    content = content\n                ))\n\n                viewModelScope.launch {\n                    val responses = engine.chat.completions.create(\n                        messages = historyMessages,\n                        stream_options = OpenAIProtocol.StreamOptions(include_usage = true)\n                    )\n\n                    var finishReasonLength = false\n                    var streamingText = \"\"\n\n                    for (res in responses) {\n                        if (!callBackend {\n                            for (choice in res.choices) {\n                                choice.delta.content?.let { content ->\n                                    streamingText += content.asText()\n                                }\n                                choice.finish_reason?.let { finishReason ->\n                                    if (finishReason == \"length\") {\n                                        finishReasonLength = true\n                                    }\n                                }\n                            }\n                            updateMessage(MessageRole.Assistant, streamingText)\n                            res.usage?.let { finalUsage ->\n                                report.value = finalUsage.extra?.asTextLabel() ?: \"\"\n                            }\n                            if (finishReasonLength) {\n                                streamingText += \" [output truncated due to context length limit...]\"\n                                updateMessage(MessageRole.Assistant, streamingText)\n                            }\n                        });\n                    }\n                    if (streamingText.isNotEmpty()) {\n                        
historyMessages.add(ChatCompletionMessage(\n                            role = OpenAIProtocol.ChatCompletionRole.assistant,\n                            content = streamingText\n                        ))\n                        streamingText = \"\"\n                    } else {\n                        if (historyMessages.isNotEmpty()) {\n                            historyMessages.removeAt(historyMessages.size - 1)\n                        }\n                    }\n\n                    if (modelChatState.value == ModelChatState.Generating) switchToReady()\n                }\n            }\n        }\n\n        private fun appendMessage(role: MessageRole, text: String) {\n            messages.add(MessageData(role, text))\n        }\n\n\n        private fun updateMessage(role: MessageRole, text: String) {\n            messages[messages.size - 1] = MessageData(role, text)\n        }\n\n        fun chatable(): Boolean {\n            return modelChatState.value == ModelChatState.Ready\n        }\n\n        fun interruptable(): Boolean {\n            return modelChatState.value == ModelChatState.Ready\n                    || modelChatState.value == ModelChatState.Generating\n                    || modelChatState.value == ModelChatState.Falied\n        }\n    }\n}\n\nenum class ModelInitState {\n    Initializing,\n    Indexing,\n    Paused,\n    Downloading,\n    Pausing,\n    Clearing,\n    Deleting,\n    Finished\n}\n\nenum class ModelChatState {\n    Generating,\n    Resetting,\n    Reloading,\n    Terminating,\n    Ready,\n    Falied\n}\n\nenum class MessageRole {\n    Assistant,\n    User\n}\n\ndata class DownloadTask(val url: URL, val file: File)\n\ndata class MessageData(val role: MessageRole, val text: String, val id: UUID = UUID.randomUUID(), var imageUri: Uri? 
= null)\n\ndata class AppConfig(\n    @SerializedName(\"model_libs\") var modelLibs: MutableList<String>,\n    @SerializedName(\"model_list\") val modelList: MutableList<ModelRecord>,\n)\n\ndata class ModelRecord(\n    @SerializedName(\"model_url\") val modelUrl: String,\n    @SerializedName(\"model_id\") val modelId: String,\n    @SerializedName(\"estimated_vram_bytes\") val estimatedVramBytes: Long?,\n    @SerializedName(\"model_lib\") val modelLib: String\n)\n\ndata class ModelConfig(\n    @SerializedName(\"model_lib\") var modelLib: String,\n    @SerializedName(\"model_id\") var modelId: String,\n    @SerializedName(\"estimated_vram_bytes\") var estimatedVramBytes: Long?,\n    @SerializedName(\"tokenizer_files\") val tokenizerFiles: List<String>,\n    @SerializedName(\"context_window_size\") val contextWindowSize: Int,\n    @SerializedName(\"prefill_chunk_size\") val prefillChunkSize: Int,\n)\n\ndata class ParamsRecord(\n    @SerializedName(\"dataPath\") val dataPath: String\n)\n\ndata class ParamsConfig(\n    @SerializedName(\"records\") val paramsRecords: List<ParamsRecord>\n)\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ChatView.kt",
    "content": "package ai.mlc.mlcchat\n\nimport android.app.Activity\nimport android.graphics.Bitmap\nimport android.graphics.BitmapFactory\nimport androidx.compose.foundation.Image\nimport androidx.compose.foundation.background\nimport androidx.compose.foundation.gestures.detectTapGestures\nimport androidx.compose.foundation.layout.Arrangement\nimport androidx.compose.foundation.layout.Column\nimport androidx.compose.foundation.layout.IntrinsicSize\nimport androidx.compose.foundation.layout.Row\nimport androidx.compose.foundation.layout.aspectRatio\nimport androidx.compose.foundation.layout.fillMaxSize\nimport androidx.compose.foundation.layout.fillMaxWidth\nimport androidx.compose.foundation.layout.height\nimport androidx.compose.foundation.layout.padding\nimport androidx.compose.foundation.layout.widthIn\nimport androidx.compose.foundation.layout.wrapContentHeight\nimport androidx.compose.foundation.layout.wrapContentWidth\nimport androidx.compose.foundation.lazy.LazyColumn\nimport androidx.compose.foundation.lazy.items\nimport androidx.compose.foundation.lazy.rememberLazyListState\nimport androidx.compose.foundation.shape.RoundedCornerShape\nimport androidx.compose.foundation.text.selection.SelectionContainer\nimport androidx.compose.material.icons.Icons\nimport androidx.compose.material.icons.filled.AddAPhoto\nimport androidx.compose.material.icons.filled.ArrowBack\nimport androidx.compose.material.icons.filled.Photo\nimport androidx.compose.material.icons.filled.Replay\nimport androidx.compose.material.icons.filled.Send\nimport androidx.compose.material3.Divider\nimport androidx.compose.material3.ExperimentalMaterial3Api\nimport androidx.compose.material3.Icon\nimport androidx.compose.material3.IconButton\nimport androidx.compose.material3.MaterialTheme\nimport androidx.compose.material3.OutlinedTextField\nimport androidx.compose.material3.Scaffold\nimport androidx.compose.material3.Switch\nimport androidx.compose.material3.Text\nimport 
androidx.compose.material3.TopAppBar\nimport androidx.compose.material3.TopAppBarDefaults\nimport androidx.compose.runtime.Composable\nimport androidx.compose.runtime.getValue\nimport androidx.compose.runtime.mutableStateOf\nimport androidx.compose.runtime.remember\nimport androidx.compose.runtime.rememberCoroutineScope\nimport androidx.compose.runtime.saveable.rememberSaveable\nimport androidx.compose.runtime.setValue\nimport androidx.compose.ui.Alignment\nimport androidx.compose.ui.Modifier\nimport androidx.compose.ui.graphics.asImageBitmap\nimport androidx.compose.ui.input.pointer.pointerInput\nimport androidx.compose.ui.platform.LocalFocusManager\nimport androidx.compose.ui.text.style.TextAlign\nimport androidx.compose.ui.tooling.preview.Preview\nimport androidx.compose.ui.unit.dp\nimport androidx.navigation.NavController\nimport dev.jeziellago.compose.markdowntext.MarkdownText\nimport kotlinx.coroutines.launch\n\n@ExperimentalMaterial3Api\n@Composable\nfun ChatView(\n    navController: NavController, chatState: AppViewModel.ChatState, activity: Activity\n) {\n    val localFocusManager = LocalFocusManager.current\n    (activity as MainActivity).chatState = chatState\n    Scaffold(topBar = {\n        TopAppBar(\n            title = {\n                Text(\n                    text = \"MLCChat: \" + chatState.modelName.value.split(\"-\")[0],\n                    color = MaterialTheme.colorScheme.onPrimary\n                )\n            },\n            colors = TopAppBarDefaults.topAppBarColors(containerColor = MaterialTheme.colorScheme.primary),\n            navigationIcon = {\n                IconButton(\n                    onClick = { navController.popBackStack() },\n                    enabled = chatState.interruptable()\n                ) {\n                    Icon(\n                        imageVector = Icons.Filled.ArrowBack,\n                        contentDescription = \"back home page\",\n                        tint = 
MaterialTheme.colorScheme.onPrimary\n                    )\n                }\n            },\n            actions = {\n                IconButton(\n                    onClick = {\n                        chatState.requestResetChat()\n                        activity.hasImage = false },\n                    enabled = chatState.interruptable()\n                ) {\n                    Icon(\n                        imageVector = Icons.Filled.Replay,\n                        contentDescription = \"reset the chat\",\n                        tint = MaterialTheme.colorScheme.onPrimary\n                    )\n                }\n            })\n    }, modifier = Modifier.pointerInput(Unit) {\n        detectTapGestures(onTap = {\n            localFocusManager.clearFocus()\n        })\n    }) { paddingValues ->\n        Column(\n            modifier = Modifier\n                .fillMaxSize()\n                .padding(paddingValues)\n                .padding(horizontal = 10.dp)\n        ) {\n            val lazyColumnListState = rememberLazyListState()\n            val coroutineScope = rememberCoroutineScope()\n            Text(\n                text = chatState.report.value,\n                textAlign = TextAlign.Center,\n                modifier = Modifier\n                    .fillMaxWidth()\n                    .wrapContentHeight()\n                    .padding(top = 5.dp)\n            )\n            Divider(thickness = 1.dp, modifier = Modifier.padding(vertical = 5.dp))\n            LazyColumn(\n                modifier = Modifier.weight(9f),\n                verticalArrangement = Arrangement.spacedBy(5.dp, alignment = Alignment.Bottom),\n                state = lazyColumnListState\n            ) {\n                coroutineScope.launch {\n                    lazyColumnListState.animateScrollToItem(chatState.messages.size)\n                }\n                items(\n                    items = chatState.messages,\n                    key = { message -> message.id },\n  
              ) { message ->\n                    MessageView(messageData = message, activity)\n                }\n                item {\n                    // place holder item for scrolling to the bottom\n                }\n            }\n            Divider(thickness = 1.dp, modifier = Modifier.padding(top = 5.dp))\n            SendMessageView(chatState = chatState, activity)\n        }\n    }\n}\n\n@Composable\nfun MessageView(messageData: MessageData, activity: Activity?) {\n    // default render the Assistant text as MarkdownText\n    var useMarkdown by remember { mutableStateOf(true) }\n    var localActivity : MainActivity = activity as MainActivity\n    SelectionContainer {\n        if (messageData.role == MessageRole.Assistant) {\n            Column {\n                if (messageData.text.isNotEmpty()) {\n                    Row(\n                        verticalAlignment = Alignment.CenterVertically,\n                    ) {\n                        Text(\n                            text = \"Show as Markdown\",\n                            color = MaterialTheme.colorScheme.onSecondaryContainer,\n                            modifier = Modifier\n                                .wrapContentWidth()\n                                .padding(end = 8.dp)\n                                .widthIn(max = 300.dp)\n                        )\n                        Switch(\n                            checked = useMarkdown,\n                            onCheckedChange = { useMarkdown = it }\n                        )\n                    }\n                }\n                Row(\n                    horizontalArrangement = Arrangement.Start,\n                    modifier = Modifier.fillMaxWidth()\n                ) {\n                    if (useMarkdown) {\n                        MarkdownText(\n                            isTextSelectable = true,\n                            modifier = Modifier\n                                .wrapContentWidth()\n               
                 .background(\n                                    color = MaterialTheme.colorScheme.secondaryContainer,\n                                    shape = RoundedCornerShape(5.dp)\n                                )\n                                .padding(5.dp)\n                                .widthIn(max = 300.dp),\n                            markdown = messageData.text,\n                        )\n                    } else {\n                        Text(\n                            text = messageData.text,\n                            textAlign = TextAlign.Left,\n                            color = MaterialTheme.colorScheme.onSecondaryContainer,\n                            modifier = Modifier\n                                .wrapContentWidth()\n                                .background(\n                                    color = MaterialTheme.colorScheme.secondaryContainer,\n                                    shape = RoundedCornerShape(5.dp)\n                                )\n                                .padding(5.dp)\n                                .widthIn(max = 300.dp)\n                        )\n                    }\n                }\n            }\n        } else {\n            Row(\n                horizontalArrangement = Arrangement.End,\n                modifier = Modifier.fillMaxWidth()\n            ) {\n                if (messageData.imageUri != null) {\n                    val uri = messageData.imageUri\n                    val bitmap = uri?.let {\n                        activity.contentResolver.openInputStream(it)?.use { input ->\n                            BitmapFactory.decodeStream(input)\n                        }\n                    }\n                    val displayBitmap = bitmap?.let { Bitmap.createScaledBitmap(it, 224, 224, true) }\n                    if (displayBitmap != null) {\n                        Image(\n                            displayBitmap.asImageBitmap(),\n                            \"\",\n  
                          modifier = Modifier\n                                .wrapContentWidth()\n                                .background(\n                                    color = MaterialTheme.colorScheme.secondaryContainer,\n                                    shape = RoundedCornerShape(5.dp)\n                                )\n                                .padding(5.dp)\n                                .widthIn(max = 300.dp)\n                        )\n                    }\n                    if (!localActivity.hasImage) {\n                        localActivity.chatState.requestImageBitmap(messageData.imageUri)\n                    }\n                    localActivity.hasImage = true\n                } else {\n                    Text(\n                        text = messageData.text,\n                        textAlign = TextAlign.Right,\n                        color = MaterialTheme.colorScheme.onPrimaryContainer,\n                        modifier = Modifier\n                            .wrapContentWidth()\n                            .background(\n                                color = MaterialTheme.colorScheme.primaryContainer,\n                                shape = RoundedCornerShape(5.dp)\n                            )\n                            .padding(5.dp)\n                            .widthIn(max = 300.dp)\n                    )\n                }\n\n            }\n        }\n    }\n}\n\n@ExperimentalMaterial3Api\n@Composable\nfun SendMessageView(chatState: AppViewModel.ChatState, activity: Activity) {\n    val localFocusManager = LocalFocusManager.current\n    val localActivity : MainActivity = activity as MainActivity\n    Row(\n        horizontalArrangement = Arrangement.spacedBy(5.dp),\n        verticalAlignment = Alignment.CenterVertically,\n        modifier = Modifier\n            .height(IntrinsicSize.Max)\n            .fillMaxWidth()\n            .padding(bottom = 5.dp)\n    ) {\n        var text by rememberSaveable { 
mutableStateOf(\"\") }\n        OutlinedTextField(\n            value = text,\n            onValueChange = { text = it },\n            label = { Text(text = \"Input\") },\n            modifier = Modifier\n                .weight(9f),\n        )\n        IconButton(\n            onClick = {\n                activity.takePhoto()\n            },\n            modifier = Modifier\n                .aspectRatio(1f)\n                .weight(1f),\n            enabled = (chatState.chatable() && !localActivity.hasImage)\n        ) {\n            Icon(\n                imageVector = Icons.Filled.AddAPhoto,\n                contentDescription = \"use camera\",\n            )\n        }\n        IconButton(\n            onClick = {\n                activity.pickImageFromGallery()\n            },\n            modifier = Modifier\n                .aspectRatio(1f)\n                .weight(1f),\n            enabled = (chatState.chatable() && !localActivity.hasImage)\n        ) {\n            Icon(\n                imageVector = Icons.Filled.Photo,\n                contentDescription = \"select image\",\n            )\n        }\n        IconButton(\n            onClick = {\n                localFocusManager.clearFocus()\n                chatState.requestGenerate(text, activity)\n                text = \"\"\n            },\n            modifier = Modifier\n                .aspectRatio(1f)\n                .weight(1f),\n            enabled = (text != \"\" && chatState.chatable())\n        ) {\n            Icon(\n                imageVector = Icons.Filled.Send,\n                contentDescription = \"send message\",\n            )\n        }\n    }\n}\n\n@Preview\n@Composable\nfun MessageViewPreviewWithMarkdown() {\n    MessageView(\n        messageData = MessageData(\n            role = MessageRole.Assistant, text = \"\"\"\n# Sample  Header\n* Markdown\n* [Link](https://example.com)\n<a href=\"https://www.google.com/\">Google</a>\n\"\"\"\n        ), null\n    )\n}\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/MainActivity.kt",
    "content": "package ai.mlc.mlcchat\n\nimport android.Manifest\nimport android.content.ContentValues\nimport android.content.pm.PackageManager\nimport android.net.Uri\nimport android.os.Build\nimport android.os.Bundle\nimport android.provider.MediaStore\nimport android.util.Log\nimport androidx.activity.ComponentActivity\nimport androidx.activity.compose.setContent\nimport androidx.activity.result.contract.ActivityResultContracts\nimport androidx.annotation.RequiresApi\nimport androidx.compose.foundation.layout.fillMaxSize\nimport androidx.compose.material3.ExperimentalMaterial3Api\nimport androidx.compose.material3.Surface\nimport androidx.compose.ui.Modifier\nimport androidx.core.content.ContextCompat\nimport ai.mlc.mlcchat.ui.theme.MLCChatTheme\nimport java.text.SimpleDateFormat\nimport java.util.Date\nimport java.util.Locale\nimport java.util.UUID\n\nclass MainActivity : ComponentActivity() {\n    var hasImage = false\n\n    private val pickImageLauncher = registerForActivityResult(\n        ActivityResultContracts.GetContent()\n    ) { uri: Uri? ->\n        uri?.let {\n            Log.v(\"pickImageLauncher\", \"Selected image uri: $it\")\n            chatState.messages.add(\n                MessageData(\n                    role = MessageRole.User,\n                    text = \"\",\n                    id = UUID.randomUUID(),\n                    imageUri = it\n                )\n            )\n        }\n    }\n\n    private var cameraImageUri: Uri? 
= null\n    private val takePictureLauncher = registerForActivityResult(\n        ActivityResultContracts.TakePicture()\n    ) { success: Boolean ->\n        if (success && cameraImageUri != null) {\n            Log.v(\"takePictureLauncher\", \"Camera image uri: $cameraImageUri\")\n            chatState.messages.add(\n                MessageData(\n                    role = MessageRole.User,\n                    text = \"\",\n                    id = UUID.randomUUID(),\n                    imageUri = cameraImageUri\n                )\n            )\n        }\n    }\n\n    private val requestPermissionLauncher =\n        registerForActivityResult(ActivityResultContracts.RequestMultiplePermissions()) { permissions ->\n            permissions.entries.forEach {\n                Log.d(\"Permissions\", \"${it.key} = ${it.value}\")\n            }\n        }\n\n    lateinit var chatState: AppViewModel.ChatState\n\n    @RequiresApi(Build.VERSION_CODES.TIRAMISU)\n    @ExperimentalMaterial3Api\n    override fun onCreate(savedInstanceState: Bundle?) 
{\n        super.onCreate(savedInstanceState)\n\n        chatState = AppViewModel(this.application).ChatState()\n        requestNeededPermissions()\n\n        setContent {\n            Surface(\n                modifier = Modifier.fillMaxSize()\n            ) {\n                MLCChatTheme {\n                    NavView(this)\n                }\n            }\n        }\n    }\n\n    private fun requestNeededPermissions() {\n        val permissionsToRequest = mutableListOf<String>()\n\n        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {\n            if (ContextCompat.checkSelfPermission(\n                    this,\n                    Manifest.permission.READ_MEDIA_IMAGES\n                ) != PackageManager.PERMISSION_GRANTED\n            ) {\n                permissionsToRequest.add(Manifest.permission.READ_MEDIA_IMAGES)\n            }\n            if (ContextCompat.checkSelfPermission(\n                    this,\n                    Manifest.permission.CAMERA\n                ) != PackageManager.PERMISSION_GRANTED\n            ) {\n                permissionsToRequest.add(Manifest.permission.CAMERA)\n            }\n        } else {\n            if (ContextCompat.checkSelfPermission(\n                    this,\n                    Manifest.permission.READ_EXTERNAL_STORAGE\n                ) != PackageManager.PERMISSION_GRANTED\n            ) {\n                permissionsToRequest.add(Manifest.permission.READ_EXTERNAL_STORAGE)\n            }\n            if (ContextCompat.checkSelfPermission(\n                    this,\n                    Manifest.permission.WRITE_EXTERNAL_STORAGE\n                ) != PackageManager.PERMISSION_GRANTED\n            ) {\n                permissionsToRequest.add(Manifest.permission.WRITE_EXTERNAL_STORAGE)\n            }\n            if (ContextCompat.checkSelfPermission(\n                    this,\n                    Manifest.permission.CAMERA\n                ) != PackageManager.PERMISSION_GRANTED\n            
) {\n                permissionsToRequest.add(Manifest.permission.CAMERA)\n            }\n        }\n\n        if (permissionsToRequest.isNotEmpty()) {\n            requestPermissionLauncher.launch(permissionsToRequest.toTypedArray())\n        }\n    }\n\n    fun pickImageFromGallery() {\n        pickImageLauncher.launch(\"image/*\")\n    }\n\n    fun takePhoto() {\n        val contentValues = ContentValues().apply {\n            val timeFormatter = SimpleDateFormat(\"yyyyMMdd_HHmmss\", Locale.getDefault())\n            val fileName = \"IMG_${timeFormatter.format(Date())}.jpg\"\n            put(MediaStore.Images.Media.DISPLAY_NAME, fileName)\n            put(MediaStore.Images.Media.MIME_TYPE, \"image/jpeg\")\n            put(MediaStore.Images.Media.DATE_ADDED, System.currentTimeMillis() / 1000)\n        }\n\n        cameraImageUri = contentResolver.insert(\n            MediaStore.Images.Media.EXTERNAL_CONTENT_URI,\n            contentValues\n        )\n\n        takePictureLauncher.launch(cameraImageUri)\n    }\n}\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/NavView.kt",
    "content": "package ai.mlc.mlcchat\n\nimport android.app.Activity\nimport androidx.compose.material3.ExperimentalMaterial3Api\nimport androidx.compose.runtime.Composable\nimport androidx.lifecycle.viewmodel.compose.viewModel\nimport androidx.navigation.compose.NavHost\nimport androidx.navigation.compose.composable\nimport androidx.navigation.compose.rememberNavController\n\n@ExperimentalMaterial3Api\n@Composable\nfun NavView(activity: Activity, appViewModel: AppViewModel = viewModel()) {\n    val navController = rememberNavController()\n    NavHost(navController = navController, startDestination = \"home\") {\n        composable(\"home\") { StartView(navController, appViewModel) }\n        composable(\"chat\") { ChatView(navController, appViewModel.chatState, activity) }\n    }\n}\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/StartView.kt",
    "content": "package ai.mlc.mlcchat\n\nimport androidx.compose.foundation.gestures.detectTapGestures\nimport androidx.compose.foundation.layout.Arrangement\nimport androidx.compose.foundation.layout.Column\nimport androidx.compose.foundation.layout.Row\nimport androidx.compose.foundation.layout.aspectRatio\nimport androidx.compose.foundation.layout.fillMaxSize\nimport androidx.compose.foundation.layout.fillMaxWidth\nimport androidx.compose.foundation.layout.height\nimport androidx.compose.foundation.layout.padding\nimport androidx.compose.foundation.layout.width\nimport androidx.compose.foundation.layout.wrapContentHeight\nimport androidx.compose.foundation.lazy.LazyColumn\nimport androidx.compose.foundation.lazy.items\nimport androidx.compose.foundation.text.selection.SelectionContainer\nimport androidx.compose.material.icons.Icons\nimport androidx.compose.material.icons.outlined.Chat\nimport androidx.compose.material.icons.outlined.Delete\nimport androidx.compose.material.icons.outlined.Download\nimport androidx.compose.material.icons.outlined.Pause\nimport androidx.compose.material.icons.outlined.Schedule\nimport androidx.compose.material3.AlertDialog\nimport androidx.compose.material3.Divider\nimport androidx.compose.material3.ExperimentalMaterial3Api\nimport androidx.compose.material3.Icon\nimport androidx.compose.material3.IconButton\nimport androidx.compose.material3.LinearProgressIndicator\nimport androidx.compose.material3.MaterialTheme\nimport androidx.compose.material3.OutlinedTextField\nimport androidx.compose.material3.Scaffold\nimport androidx.compose.material3.Text\nimport androidx.compose.material3.TextButton\nimport androidx.compose.material3.TopAppBar\nimport androidx.compose.material3.TopAppBarDefaults\nimport androidx.compose.runtime.Composable\nimport androidx.compose.runtime.getValue\nimport androidx.compose.runtime.mutableStateOf\nimport androidx.compose.runtime.saveable.rememberSaveable\nimport androidx.compose.runtime.setValue\nimport 
androidx.compose.ui.Alignment\nimport androidx.compose.ui.Modifier\nimport androidx.compose.ui.input.pointer.pointerInput\nimport androidx.compose.ui.platform.LocalFocusManager\nimport androidx.compose.ui.text.style.TextAlign\nimport androidx.compose.ui.unit.dp\nimport androidx.navigation.NavController\n\n\n@ExperimentalMaterial3Api\n@Composable\nfun StartView(\n    navController: NavController,\n    appViewModel: AppViewModel\n) {\n    val localFocusManager = LocalFocusManager.current\n    Scaffold(\n        topBar = {\n            TopAppBar(\n                title = { Text(text = \"MLCChat\", color = MaterialTheme.colorScheme.onPrimary) },\n                colors = TopAppBarDefaults.topAppBarColors(containerColor = MaterialTheme.colorScheme.primary)\n            )\n        },\n        modifier = Modifier.pointerInput(Unit) {\n            detectTapGestures(onTap = {\n                localFocusManager.clearFocus()\n            })\n        }\n    )\n    { paddingValues ->\n        Column(\n            modifier = Modifier\n                .fillMaxSize()\n                .padding(paddingValues)\n                .padding(horizontal = 10.dp)\n        ) {\n            Text(text = \"Model List\", modifier = Modifier.padding(top = 10.dp))\n            LazyColumn() {\n                items(items = appViewModel.modelList,\n                    key = { modelState -> modelState.id }\n                ) { modelState ->\n                    ModelView(\n                        navController = navController,\n                        modelState = modelState,\n                        appViewModel = appViewModel\n                    )\n                }\n            }\n        }\n        if (appViewModel.isShowingAlert()) {\n            AlertDialog(\n                onDismissRequest = { appViewModel.dismissAlert() },\n                onConfirmation = { appViewModel.copyError() },\n                error = appViewModel.errorMessage()\n            )\n        }\n    
}\n}\n\n@ExperimentalMaterial3Api\n@Composable\nfun AlertDialog(\n    onDismissRequest: () -> Unit,\n    onConfirmation: () -> Unit,\n    error: String,\n) {\n    AlertDialog(\n        title = { Text(text = \"Error\") },\n        text = { Text(text = error) },\n        onDismissRequest = { onDismissRequest() },\n        confirmButton = {\n            TextButton(onClick = { onConfirmation() }) { Text(\"Copy\") }\n        },\n        dismissButton = {\n            TextButton(onClick = { onDismissRequest() }) { Text(\"Dismiss\") }\n        }\n    )\n}\n\n@Composable\nfun ModelView(\n    navController: NavController,\n    modelState: AppViewModel.ModelState,\n    appViewModel: AppViewModel\n) {\n    var isDeletingModel by rememberSaveable { mutableStateOf(false) }\n    Column(\n        verticalArrangement = Arrangement.SpaceBetween,\n        modifier = Modifier\n            .wrapContentHeight()\n    ) {\n        Row(\n            horizontalArrangement = Arrangement.spacedBy(5.dp),\n            verticalAlignment = Alignment.CenterVertically,\n            modifier = Modifier\n                .fillMaxWidth()\n                .wrapContentHeight()\n        ) {\n            Text(\n                text = modelState.modelConfig.modelId,\n                textAlign = TextAlign.Left,\n                modifier = Modifier\n                    .wrapContentHeight()\n                    .weight(8f)\n            )\n            Divider(\n                modifier = Modifier\n                    .height(20.dp)\n                    .width(1.dp)\n            )\n            if (modelState.modelInitState.value == ModelInitState.Paused) {\n                IconButton(\n                    onClick = { modelState.handleStart() }, modifier = Modifier\n                        .aspectRatio(1f)\n                        .weight(1f)\n                ) {\n                    Icon(\n                        imageVector = Icons.Outlined.Download,\n                        contentDescription = \"start 
downloading\",\n                    )\n                }\n\n            } else if (modelState.modelInitState.value == ModelInitState.Downloading) {\n                IconButton(\n                    onClick = { modelState.handlePause() }, modifier = Modifier\n                        .aspectRatio(1f)\n                        .weight(1f)\n                ) {\n                    Icon(\n                        imageVector = Icons.Outlined.Pause,\n                        contentDescription = \"pause downloading\",\n                    )\n                }\n            } else if (modelState.modelInitState.value == ModelInitState.Finished) {\n                IconButton(\n                    onClick = {\n                        modelState.startChat()\n                        navController.navigate(\"chat\")\n                    },\n                    enabled = appViewModel.chatState.interruptable(),\n                    modifier = Modifier\n                        .aspectRatio(1f)\n                        .weight(1f)\n                ) {\n                    Icon(\n                        imageVector = Icons.Outlined.Chat,\n                        contentDescription = \"start chatting\",\n                    )\n                }\n            } else {\n                IconButton(\n                    enabled = false, onClick = {}, modifier = Modifier\n                        .aspectRatio(1f)\n                        .weight(1f)\n                ) {\n                    Icon(\n                        imageVector = Icons.Outlined.Schedule,\n                        contentDescription = \"pending\",\n                    )\n                }\n            }\n            if (modelState.modelInitState.value == ModelInitState.Downloading ||\n                modelState.modelInitState.value == ModelInitState.Paused ||\n                modelState.modelInitState.value == ModelInitState.Finished\n            ) {\n                IconButton(\n                    onClick = { 
isDeletingModel = true },\n                    modifier = Modifier\n                        .aspectRatio(1f)\n                        .weight(1f)\n                ) {\n                    Icon(\n                        imageVector = Icons.Outlined.Delete,\n                        contentDescription = \"delete model\",\n                        tint = MaterialTheme.colorScheme.error\n                    )\n                }\n            }\n        }\n        LinearProgressIndicator(\n            progress = modelState.progress.value.toFloat() / modelState.total.value,\n            modifier = Modifier.fillMaxWidth()\n        )\n        if (isDeletingModel) {\n            Row(\n                horizontalArrangement = Arrangement.End,\n                verticalAlignment = Alignment.CenterVertically,\n                modifier = Modifier\n                    .fillMaxWidth()\n                    .wrapContentHeight()\n            ) {\n                TextButton(onClick = { isDeletingModel = false }) {\n                    Text(text = \"cancel\")\n                }\n                TextButton(onClick = {\n                    isDeletingModel = false\n                    modelState.handleClear()\n                }) {\n                    Text(text = \"clear data\", color = MaterialTheme.colorScheme.error)\n                }\n                TextButton(onClick = {\n                    isDeletingModel = false\n                    modelState.handleDelete()\n                }) {\n                    Text(text = \"delete model\", color = MaterialTheme.colorScheme.error)\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Color.kt",
    "content": "package ai.mlc.mlcchat.ui.theme\n\nimport androidx.compose.ui.graphics.Color\n\nval Blue10 = Color(0xFF000F5E)\nval Blue20 = Color(0xFF001E92)\nval Blue30 = Color(0xFF002ECC)\nval Blue40 = Color(0xFF1546F6)\nval Blue80 = Color(0xFFB8C3FF)\nval Blue90 = Color(0xFFDDE1FF)\n\nval DarkBlue10 = Color(0xFF00036B)\nval DarkBlue20 = Color(0xFF000BA6)\nval DarkBlue30 = Color(0xFF1026D3)\nval DarkBlue40 = Color(0xFF3648EA)\nval DarkBlue80 = Color(0xFFBBC2FF)\nval DarkBlue90 = Color(0xFFDEE0FF)\n\nval Yellow10 = Color(0xFF261900)\nval Yellow20 = Color(0xFF402D00)\nval Yellow30 = Color(0xFF5C4200)\nval Yellow40 = Color(0xFF7A5900)\nval Yellow80 = Color(0xFFFABD1B)\nval Yellow90 = Color(0xFFFFDE9C)\n\nval Red10 = Color(0xFF410001)\nval Red20 = Color(0xFF680003)\nval Red30 = Color(0xFF930006)\nval Red40 = Color(0xFFBA1B1B)\nval Red80 = Color(0xFFFFB4A9)\nval Red90 = Color(0xFFFFDAD4)\n\nval Grey10 = Color(0xFF191C1D)\nval Grey20 = Color(0xFF2D3132)\nval Grey80 = Color(0xFFC4C7C7)\nval Grey90 = Color(0xFFE0E3E3)\nval Grey95 = Color(0xFFEFF1F1)\nval Grey99 = Color(0xFFFBFDFD)\n\nval BlueGrey30 = Color(0xFF45464F)\nval BlueGrey50 = Color(0xFF767680)\nval BlueGrey60 = Color(0xFF90909A)\nval BlueGrey80 = Color(0xFFC6C5D0)\nval BlueGrey90 = Color(0xFFE2E1EC)\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Theme.kt",
    "content": "package ai.mlc.mlcchat.ui.theme\n\nimport android.app.Activity\nimport android.os.Build\nimport androidx.compose.foundation.isSystemInDarkTheme\nimport androidx.compose.material3.MaterialTheme\nimport androidx.compose.material3.darkColorScheme\nimport androidx.compose.material3.dynamicDarkColorScheme\nimport androidx.compose.material3.dynamicLightColorScheme\nimport androidx.compose.material3.lightColorScheme\nimport androidx.compose.runtime.Composable\nimport androidx.compose.runtime.SideEffect\nimport androidx.compose.ui.graphics.Color\nimport androidx.compose.ui.graphics.toArgb\nimport androidx.compose.ui.platform.LocalContext\nimport androidx.compose.ui.platform.LocalView\nimport androidx.core.view.WindowCompat\n\nprivate val DarkColorScheme = darkColorScheme(\n    primary = Blue80,\n    onPrimary = Blue20,\n    primaryContainer = Blue30,\n    onPrimaryContainer = Blue90,\n    inversePrimary = Blue40,\n    secondary = DarkBlue80,\n    onSecondary = DarkBlue20,\n    secondaryContainer = DarkBlue30,\n    onSecondaryContainer = DarkBlue90,\n    tertiary = Yellow80,\n    onTertiary = Yellow20,\n    tertiaryContainer = Yellow30,\n    onTertiaryContainer = Yellow90,\n    error = Red80,\n    onError = Red20,\n    errorContainer = Red30,\n    onErrorContainer = Red90,\n    background = Grey10,\n    onBackground = Grey90,\n    surface = Grey10,\n    onSurface = Grey80,\n    inverseSurface = Grey90,\n    inverseOnSurface = Grey20,\n    surfaceVariant = BlueGrey30,\n    onSurfaceVariant = BlueGrey80,\n    outline = BlueGrey60\n)\n\nprivate val LightColorScheme = lightColorScheme(\n    primary = Blue40,\n    onPrimary = Color.White,\n    primaryContainer = Blue90,\n    onPrimaryContainer = Blue10,\n    inversePrimary = Blue80,\n    secondary = DarkBlue40,\n    onSecondary = Color.White,\n    secondaryContainer = DarkBlue90,\n    onSecondaryContainer = DarkBlue10,\n    tertiary = Yellow40,\n    onTertiary = Color.White,\n    tertiaryContainer = Yellow90,\n   
 onTertiaryContainer = Yellow10,\n    error = Red40,\n    onError = Color.White,\n    errorContainer = Red90,\n    onErrorContainer = Red10,\n    background = Grey99,\n    onBackground = Grey10,\n    surface = Grey99,\n    onSurface = Grey10,\n    inverseSurface = Grey20,\n    inverseOnSurface = Grey95,\n    surfaceVariant = BlueGrey90,\n    onSurfaceVariant = BlueGrey30,\n    outline = BlueGrey50\n)\n\n@Composable\nfun MLCChatTheme(\n    darkTheme: Boolean = isSystemInDarkTheme(),\n    // Dynamic color is available on Android 12+\n    dynamicColor: Boolean = true,\n    content: @Composable () -> Unit\n) {\n    val colorScheme = when {\n        dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {\n            val context = LocalContext.current\n            if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)\n        }\n\n        darkTheme -> DarkColorScheme\n        else -> LightColorScheme\n    }\n    val view = LocalView.current\n    if (!view.isInEditMode) {\n        SideEffect {\n            val window = (view.context as Activity).window\n            window.statusBarColor = colorScheme.primary.toArgb()\n            WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme\n        }\n    }\n\n    MaterialTheme(\n        colorScheme = colorScheme,\n        typography = Typography,\n        content = content\n    )\n}\n"
  },
  {
    "path": "android/MLCChat/app/src/main/java/ai/mlc/mlcchat/ui/theme/Type.kt",
    "content": "package ai.mlc.mlcchat.ui.theme\n\nimport androidx.compose.material3.Typography\nimport androidx.compose.ui.text.TextStyle\nimport androidx.compose.ui.text.font.FontFamily\nimport androidx.compose.ui.text.font.FontWeight\nimport androidx.compose.ui.unit.sp\n\n// Set of Material typography styles to start with\nval Typography = Typography(\n    bodyLarge = TextStyle(\n        fontFamily = FontFamily.Default,\n        fontWeight = FontWeight.Normal,\n        fontSize = 16.sp,\n        lineHeight = 24.sp,\n        letterSpacing = 0.5.sp\n    )\n    /* Other default text styles to override\n    titleLarge = TextStyle(\n        fontFamily = FontFamily.Default,\n        fontWeight = FontWeight.Normal,\n        fontSize = 22.sp,\n        lineHeight = 28.sp,\n        letterSpacing = 0.sp\n    ),\n    labelSmall = TextStyle(\n        fontFamily = FontFamily.Default,\n        fontWeight = FontWeight.Medium,\n        fontSize = 11.sp,\n        lineHeight = 16.sp,\n        letterSpacing = 0.5.sp\n    )\n    */\n)\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/drawable/ic_android_black_24dp.xml",
    "content": "<vector android:height=\"24dp\" android:tint=\"#000000\"\n    android:viewportHeight=\"24\" android:viewportWidth=\"24\"\n    android:width=\"24dp\" xmlns:android=\"http://schemas.android.com/apk/res/android\">\n    <path android:fillColor=\"#FF000000\" android:pathData=\"M17.6,11.48 L19.44,8.3a0.63,0.63 0,0 0,-1.09 -0.63l-1.88,3.24a11.43,11.43 0,0 0,-8.94 0L5.65,7.67a0.63,0.63 0,0 0,-1.09 0.63L6.4,11.48A10.81,10.81 0,0 0,1 20L23,20A10.81,10.81 0,0 0,17.6 11.48ZM7,17.25A1.25,1.25 0,1 1,8.25 16,1.25 1.25,0 0,1 7,17.25ZM17,17.25A1.25,1.25 0,1 1,18.25 16,1.25 1.25,0 0,1 17,17.25Z\"/>\n</vector>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/drawable/mlc_logo_108.xml",
    "content": "<vector xmlns:android=\"http://schemas.android.com/apk/res/android\"\n    android:width=\"108dp\"\n    android:height=\"108dp\"\n    android:viewportWidth=\"108\"\n    android:viewportHeight=\"108\">\n  <path\n      android:pathData=\"M100.93,47.91L58.41,47.91C57,47.91 55.82,49.05 55.82,50.5L55.82,69.17C57.54,68.98 59.14,69.2 60.55,70.04L60.55,52.72L98.79,52.72L98.79,103.09L60.55,103.09L60.55,87.29C59.48,88.09 58.26,88.75 57.08,89.28C56.7,89.47 56.27,89.66 55.82,89.81L55.82,105.23C55.82,106.64 56.96,107.82 58.41,107.82L100.93,107.82C102.34,107.82 103.52,106.68 103.52,105.23L103.52,50.5C103.52,49.09 102.34,47.91 100.93,47.91ZM55.93,86.72C52.88,88.13 47.57,89.39 44.29,90.12C40.63,90.92 30.02,93.36 29.1,87.6C28.34,82.87 40.82,77.6 44.02,76.23C46.96,74.97 49.94,73.79 52.92,72.64C56.12,71.42 59.56,70.88 61.16,74.93C61.85,76.68 62.04,78.14 62.07,80L62.07,80.2L62.04,80.39C61.66,83.55 58.57,85.5 55.93,86.72ZM66.58,35.01C68.03,34.9 69.29,35.96 69.4,37.38L69.82,42.3C69.94,43.75 68.87,45.01 67.46,45.13C66.01,45.24 64.75,44.17 64.63,42.76L64.21,37.84C64.06,36.42 65.13,35.16 66.58,35.01ZM85.55,45.96C85.55,43.03 85.39,40.2 85.13,37.53L85.13,37.57C85.01,36.65 84.9,35.74 84.78,34.82L84.78,34.75L84.75,34.59C84.25,31.23 83.6,27.91 82.8,24.55C82,21.2 79.1,19.1 75.7,19.02C69.52,18.87 63.3,18.79 57.11,19.02C56.05,17.07 52.88,15.86 49.18,16.12L44.9,6.5C45.7,5.78 46.13,4.67 46.05,3.53C45.86,1.5 44.1,0.02 42.08,0.21C40.05,0.4 38.57,2.15 38.76,4.18C38.91,6.09 40.52,7.5 42.38,7.5L46.43,16.58C43.76,17.3 41.77,18.79 41.24,20.47C35.4,21.31 29.64,22.46 23.88,23.6C23.19,23.75 22.54,23.95 21.93,24.25C15.44,25.81 8.76,29.36 8.76,29.36C8.69,30.2 8.61,31.04 8.54,31.92C8.84,31.84 9.18,31.8 9.53,31.77C14.6,31.31 19.22,36.54 19.79,43.41C20.4,50.28 16.78,56.23 11.7,56.65C11.32,56.69 10.94,56.69 10.55,56.65C10.79,57.57 11.02,58.44 11.24,59.32C15.48,61.61 21.2,63.18 24.75,63.94C25.59,64.24 26.47,64.43 27.43,64.43C36.13,64.63 44.82,64.74 53.53,63.98L53.87,63.94L53.87,57.57C47.99,58.06 
41.96,58.1 32.92,57.91C31.2,57.87 29.94,56.84 29.52,55.35C27.5,48.02 27,40.43 27.46,32.23C27.54,30.66 28.64,29.44 30.36,29.1C32,28.75 33.57,28.45 35.02,28.18C35.71,28.07 37.5,27.72 38.91,27.46L38.95,27.46C40.02,27.27 41.05,27.04 42.12,26.88C45.44,27.46 47.73,32.3 52.69,31.5C57.69,31.43 59.1,26.23 62.3,25.09C64.63,25.09 66.92,25.13 69.25,25.24C70.36,25.24 71.43,25.28 73.14,25.32C74.86,25.36 76.16,26.39 76.54,27.88C76.92,29.52 77.27,31.12 77.57,32.72C78.18,38.22 78.45,42.64 78.41,46.04L85.55,46.04ZM9.79,38.06C11.78,37.88 13.65,40.58 13.95,44.09C14.26,47.61 12.92,50.58 10.94,50.73C9.98,50.81 9.03,50.24 8.3,49.17C8.72,49.51 9.18,49.66 9.64,49.63C11.2,49.48 12.27,47.07 12.01,44.25C11.74,41.42 10.29,39.25 8.72,39.36C8.27,39.4 7.85,39.63 7.5,40.05C8,38.9 8.8,38.18 9.79,38.06ZM52.65,21.88C54.29,21.88 55.59,23.22 55.59,24.82C55.59,26.46 54.25,27.76 52.65,27.76C51.01,27.76 49.71,26.43 49.71,24.82C49.67,23.22 51.01,21.88 52.65,21.88ZM42.31,37.19C43.76,37.07 45.02,38.14 45.13,39.55L45.55,44.48C45.66,45.93 44.6,47.18 43.18,47.3C41.73,47.41 40.48,46.34 40.36,44.93L39.94,40.01C39.79,38.56 40.86,37.3 42.31,37.19ZM9.75,34.06C13.5,33.71 16.97,37.95 17.43,43.52C17.92,49.09 15.29,53.86 11.51,54.17C7.77,54.51 4.3,50.28 3.84,44.7C3.34,39.17 5.98,34.4 9.75,34.06ZM53.91,100.73C49.98,99.7 46.54,97.1 45.02,92.79C47.77,92.18 51.01,91.45 53.91,90.46ZM42.84,73.79L42.19,66.46L53.91,65.85L53.91,69.47C53.26,69.63 52.61,69.86 51.96,70.08C48.95,71.23 45.93,72.45 42.96,73.71ZM29.64,73.59C33.19,71 37.61,71.04 39.52,73.67C39.83,74.09 40.02,74.51 40.17,74.97C35.82,76.91 29.83,80 27.43,83.86C27.12,83.63 26.85,83.32 26.62,83.02C24.71,80.39 26.05,76.15 29.64,73.59ZM78.68,84.28C79.36,84.13 80.09,84.13 80.77,84.28L81.58,82.91L81.92,83.02C82.61,83.29 83.25,83.63 83.83,84.13L84.09,84.36L83.33,85.77C83.56,86.04 83.79,86.3 83.94,86.61C84.13,86.91 84.25,87.22 84.36,87.56L85.96,87.56L86.04,87.94C86.16,88.67 86.16,89.43 86.04,90.16L85.96,90.5L84.36,90.54C84.13,91.23 83.79,91.84 
83.33,92.33L84.13,93.71L83.87,93.93C83.56,94.16 83.25,94.39 82.95,94.58C82.64,94.77 82.3,94.93 81.96,95.04L81.61,95.16L80.77,93.78C80.09,93.93 79.36,93.93 78.68,93.78L77.88,95.16L77.53,95.04C76.84,94.77 76.2,94.43 75.63,93.93L75.36,93.71L76.12,92.29C75.89,92.03 75.66,91.76 75.51,91.45C75.32,91.15 75.2,90.84 75.09,90.5L73.48,90.5L73.41,90.12C73.3,89.39 73.3,88.63 73.41,87.91L73.48,87.56L75.09,87.52C75.32,86.84 75.66,86.23 76.12,85.73L75.32,84.36L75.59,84.13C75.89,83.9 76.2,83.67 76.5,83.48C76.8,83.29 77.15,83.13 77.49,83.02L77.84,82.91ZM64.18,57.76L94.36,57.76L94.36,61L64.18,61ZM64.18,64.97L76.39,64.97L76.39,68.21L64.18,68.21ZM64.18,72.34L74.02,72.34L74.02,75.58L64.25,75.58L64.18,75.31ZM90.09,67.49C91,67.79 91.84,68.29 92.57,68.9L94.48,67.79L94.82,68.18C95.47,68.94 96,69.86 96.34,70.81L96.54,71.27L94.67,72.41C94.78,72.87 94.82,73.36 94.82,73.86C94.82,74.36 94.78,74.82 94.67,75.27L96.57,76.38L96.38,76.84C96.04,77.79 95.51,78.67 94.86,79.47L94.55,79.85L92.64,78.79C91.92,79.43 91.08,79.93 90.16,80.23L90.16,82.41L89.67,82.48C89.17,82.56 88.64,82.64 88.14,82.64C87.64,82.64 87.15,82.6 86.65,82.52L86.16,82.45L86.12,80.23C85.2,79.93 84.36,79.43 83.64,78.82L81.73,79.93L81.39,79.55C80.74,78.79 80.2,77.87 79.86,76.91L79.67,76.46L81.54,75.31C81.43,74.85 81.39,74.36 81.39,73.86C81.39,73.36 81.43,72.91 81.54,72.45L79.63,71.34L79.82,70.85C80.16,69.89 80.7,69.02 81.35,68.21L81.65,67.83L83.56,68.9C84.29,68.25 85.13,67.75 86.04,67.45L86.04,65.31L86.54,65.23C87.04,65.16 87.57,65.08 88.06,65.08C88.56,65.08 89.05,65.12 89.55,65.2L90.05,65.27ZM88.06,70.54C86.2,70.54 84.71,72.03 84.71,73.9C84.71,75.77 86.2,77.26 88.06,77.26C89.93,77.26 91.42,75.77 91.42,73.9C91.42,72.03 89.89,70.54 88.06,70.54ZM78.48,86.95C77.3,87.64 76.92,89.13 77.61,90.31C78.29,91.49 79.78,91.88 80.96,91.19C82.15,90.5 82.53,89.01 81.84,87.83C81.16,86.68 79.67,86.26 78.48,86.95ZM78.48,86.95\"\n      android:fillColor=\"#062578\"\n      android:fillType=\"evenOdd\"\n      android:strokeColor=\"#00000000\"/>\n</vector>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/values/colors.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<resources>\n    <color name=\"purple_200\">#FFBB86FC</color>\n    <color name=\"purple_500\">#FF6200EE</color>\n    <color name=\"purple_700\">#FF3700B3</color>\n    <color name=\"teal_200\">#FF03DAC5</color>\n    <color name=\"teal_700\">#FF018786</color>\n    <color name=\"black\">#FF000000</color>\n    <color name=\"white\">#FFFFFFFF</color>\n</resources>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/values/strings.xml",
    "content": "<resources>\n    <string name=\"app_name\">MLCChat</string>\n</resources>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/values/themes.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<resources>\n\n    <style name=\"Theme.MLCChat\" parent=\"android:Theme.Material.Light\" />\n\n</resources>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/xml/backup_rules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?><!--\n   Sample backup rules file; uncomment and customize as necessary.\n   See https://developer.android.com/guide/topics/data/autobackup\n   for details.\n   Note: This file is ignored for devices older than API 31\n   See https://developer.android.com/about/versions/12/backup-restore\n-->\n<full-backup-content>\n    <!--\n   <include domain=\"sharedpref\" path=\".\"/>\n   <exclude domain=\"sharedpref\" path=\"device.xml\"/>\n-->\n</full-backup-content>\n"
  },
  {
    "path": "android/MLCChat/app/src/main/res/xml/data_extraction_rules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?><!--\n   Sample data extraction rules file; uncomment and customize as necessary.\n   See https://developer.android.com/about/versions/12/backup-restore#xml-changes\n   for details.\n-->\n<data-extraction-rules>\n    <cloud-backup>\n        <!-- TODO: Use <include> and <exclude> to control what is backed up.\n        <include .../>\n        <exclude .../>\n        -->\n    </cloud-backup>\n    <!--\n    <device-transfer>\n        <include .../>\n        <exclude .../>\n    </device-transfer>\n    -->\n</data-extraction-rules>\n"
  },
  {
    "path": "android/MLCChat/build.gradle",
    "content": "plugins {\n    id 'com.android.application' version '8.2.0' apply false\n    id 'com.android.library' version '8.2.0' apply false\n    id 'org.jetbrains.kotlin.android' version '1.8.10' apply false\n}\n"
  },
  {
    "path": "android/MLCChat/bundle_weight.py",
    "content": "import argparse\nimport os\nimport subprocess\nfrom pathlib import Path\n\nfrom mlc_llm.support import logging\n\nlogging.enable_logging()\nlogger = logging.getLogger(__name__)\n\n\ndef main(apk_path: Path, package_output_path: Path):\n    \"\"\"Push weights to the Android device with adb\"\"\"\n    # - Install the apk on device.\n    logger.info('Install apk \"%s\" to device', str(apk_path.absolute()))\n    subprocess.run([\"adb\", \"install\", str(apk_path)], check=True, env=os.environ)\n    # - Create the weight directory for the app.\n    device_weight_dir = \"/storage/emulated/0/Android/data/ai.mlc.mlcchat/files/\"\n    logger.info('Creating directory \"%s\" on device', device_weight_dir)\n    subprocess.run(\n        [\"adb\", \"shell\", \"mkdir\", \"-p\", device_weight_dir],\n        check=True,\n        env=os.environ,\n    )\n    for model_weight_dir in (package_output_path / \"bundle\").iterdir():\n        if model_weight_dir.is_dir():\n            # Push to a world-writable staging location first, then move into the app directory.\n            src_path = str(model_weight_dir.absolute())\n            dst_path = \"/data/local/tmp/\" + model_weight_dir.name\n            logger.info('Pushing local weights \"%s\" to device location \"%s\"', src_path, dst_path)\n            subprocess.run([\"adb\", \"push\", src_path, dst_path], check=True, env=os.environ)\n\n            src_path = dst_path\n            dst_path = device_weight_dir\n            logger.info('Move weights from \"%s\" to \"%s\"', src_path, dst_path)\n            subprocess.run([\"adb\", \"shell\", \"mv\", src_path, dst_path], check=True, env=os.environ)\n    logger.info(\"All finished.\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"MLC LLM Android Weight Bundle\")\n\n    def _parse_apk_path(path: str) -> Path:\n        path = Path(path)\n        if not path.exists():\n            raise ValueError(\n                f\"Path {str(path)} is expected to be an apk file, but the file does not exist.\"\n            )\n        if not path.is_file():\n            raise ValueError(f\"Path {str(path)} is expected to be an apk file.\")\n        return path\n\n    parser.add_argument(\n        \"--apk-path\",\n        type=_parse_apk_path,\n        default=\"app/release/app-release.apk\",\n        help=\"The path to the generated MLCChat apk file.\",\n    )\n    parser.add_argument(\n        \"--package-output-path\",\n        type=Path,\n        default=\"dist\",\n        help='The path to the output directory of \"mlc_llm package\".',\n    )\n    args = parser.parse_args()\n    main(args.apk_path, args.package_output_path)\n"
  },
  {
    "path": "android/MLCChat/gradle/wrapper/gradle-wrapper.properties",
    "content": "#Thu Jan 25 10:19:50 EST 2024\ndistributionBase=GRADLE_USER_HOME\ndistributionPath=wrapper/dists\ndistributionUrl=https\\://services.gradle.org/distributions/gradle-8.5-bin.zip\nzipStoreBase=GRADLE_USER_HOME\nzipStorePath=wrapper/dists\n"
  },
  {
    "path": "android/MLCChat/gradle.properties",
    "content": "# Project-wide Gradle settings.\n# IDE (e.g. Android Studio) users:\n# Gradle settings configured through the IDE *will override*\n# any settings specified in this file.\n# For more details on how to configure your build environment visit\n# http://www.gradle.org/docs/current/userguide/build_environment.html\n# Specifies the JVM arguments used for the daemon process.\n# The setting is particularly useful for tweaking memory settings.\norg.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8\n# When configured, Gradle will run in incubating parallel mode.\n# This option should only be used with decoupled projects. More details, visit\n# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects\n# org.gradle.parallel=true\n# AndroidX package structure to make it clearer which packages are bundled with the\n# Android operating system, and which are packaged with your app's APK\n# https://developer.android.com/topic/libraries/support-library/androidx-rn\nandroid.useAndroidX=true\n# Kotlin code style for this project: \"official\" or \"obsolete\":\nkotlin.code.style=official\n# Enables namespacing of each library's R class so that its R class includes only the\n# resources declared in the library itself and none from the library's dependencies,\n# thereby reducing the size of the R class for that library\nandroid.nonTransitiveRClass=true\n"
  },
  {
    "path": "android/MLCChat/gradlew",
    "content": "#!/usr/bin/env sh\n\n#\n# Copyright 2015 the original author or authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n##############################################################################\n##\n##  Gradle start up script for UN*X\n##\n##############################################################################\n\n# Attempt to set APP_HOME\n# Resolve links: $0 may be a link\nPRG=\"$0\"\n# Need this for relative symlinks.\nwhile [ -h \"$PRG\" ] ; do\n    ls=`ls -ld \"$PRG\"`\n    link=`expr \"$ls\" : '.*-> \\(.*\\)$'`\n    if expr \"$link\" : '/.*' > /dev/null; then\n        PRG=\"$link\"\n    else\n        PRG=`dirname \"$PRG\"`\"/$link\"\n    fi\ndone\nSAVED=\"`pwd`\"\ncd \"`dirname \\\"$PRG\\\"`/\" >/dev/null\nAPP_HOME=\"`pwd -P`\"\ncd \"$SAVED\" >/dev/null\n\nAPP_NAME=\"Gradle\"\nAPP_BASE_NAME=`basename \"$0\"`\n\n# Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.\nDEFAULT_JVM_OPTS='\"-Xmx64m\" \"-Xms64m\"'\n\n# Use the maximum available, or set MAX_FD != -1 to use that value.\nMAX_FD=\"maximum\"\n\nwarn () {\n    echo \"$*\"\n}\n\ndie () {\n    echo\n    echo \"$*\"\n    echo\n    exit 1\n}\n\n# OS specific support (must be 'true' or 'false').\ncygwin=false\nmsys=false\ndarwin=false\nnonstop=false\ncase \"`uname`\" in\n  CYGWIN* )\n    cygwin=true\n    ;;\n  Darwin* )\n    darwin=true\n    ;;\n  MINGW* )\n    msys=true\n    ;;\n  NONSTOP* )\n    nonstop=true\n    ;;\nesac\n\nCLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar\n\n\n# Determine the Java command to use to start the JVM.\nif [ -n \"$JAVA_HOME\" ] ; then\n    if [ -x \"$JAVA_HOME/jre/sh/java\" ] ; then\n        # IBM's JDK on AIX uses strange locations for the executables\n        JAVACMD=\"$JAVA_HOME/jre/sh/java\"\n    else\n        JAVACMD=\"$JAVA_HOME/bin/java\"\n    fi\n    if [ ! -x \"$JAVACMD\" ] ; then\n        die \"ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME\n\nPlease set the JAVA_HOME variable in your environment to match the\nlocation of your Java installation.\"\n    fi\nelse\n    JAVACMD=\"java\"\n    which java >/dev/null 2>&1 || die \"ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.\n\nPlease set the JAVA_HOME variable in your environment to match the\nlocation of your Java installation.\"\nfi\n\n# Increase the maximum file descriptors if we can.\nif [ \"$cygwin\" = \"false\" -a \"$darwin\" = \"false\" -a \"$nonstop\" = \"false\" ] ; then\n    MAX_FD_LIMIT=`ulimit -H -n`\n    if [ $? -eq 0 ] ; then\n        if [ \"$MAX_FD\" = \"maximum\" -o \"$MAX_FD\" = \"max\" ] ; then\n            MAX_FD=\"$MAX_FD_LIMIT\"\n        fi\n        ulimit -n $MAX_FD\n        if [ $? 
-ne 0 ] ; then\n            warn \"Could not set maximum file descriptor limit: $MAX_FD\"\n        fi\n    else\n        warn \"Could not query maximum file descriptor limit: $MAX_FD_LIMIT\"\n    fi\nfi\n\n# For Darwin, add options to specify how the application appears in the dock\nif $darwin; then\n    GRADLE_OPTS=\"$GRADLE_OPTS \\\"-Xdock:name=$APP_NAME\\\" \\\"-Xdock:icon=$APP_HOME/media/gradle.icns\\\"\"\nfi\n\n# For Cygwin or MSYS, switch paths to Windows format before running java\nif [ \"$cygwin\" = \"true\" -o \"$msys\" = \"true\" ] ; then\n    APP_HOME=`cygpath --path --mixed \"$APP_HOME\"`\n    CLASSPATH=`cygpath --path --mixed \"$CLASSPATH\"`\n\n    JAVACMD=`cygpath --unix \"$JAVACMD\"`\n\n    # We build the pattern for arguments to be converted via cygpath\n    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`\n    SEP=\"\"\n    for dir in $ROOTDIRSRAW ; do\n        ROOTDIRS=\"$ROOTDIRS$SEP$dir\"\n        SEP=\"|\"\n    done\n    OURCYGPATTERN=\"(^($ROOTDIRS))\"\n    # Add a user-defined pattern to the cygpath arguments\n    if [ \"$GRADLE_CYGPATTERN\" != \"\" ] ; then\n        OURCYGPATTERN=\"$OURCYGPATTERN|($GRADLE_CYGPATTERN)\"\n    fi\n    # Now convert the arguments - kludge to limit ourselves to /bin/sh\n    i=0\n    for arg in \"$@\" ; do\n        CHECK=`echo \"$arg\"|egrep -c \"$OURCYGPATTERN\" -`\n        CHECK2=`echo \"$arg\"|egrep -c \"^-\"`                                 ### Determine if an option\n\n        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition\n            eval `echo args$i`=`cygpath --path --ignore --mixed \"$arg\"`\n        else\n            eval `echo args$i`=\"\\\"$arg\\\"\"\n        fi\n        i=`expr $i + 1`\n    done\n    case $i in\n        0) set -- ;;\n        1) set -- \"$args0\" ;;\n        2) set -- \"$args0\" \"$args1\" ;;\n        3) set -- \"$args0\" \"$args1\" \"$args2\" ;;\n        4) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" ;;\n        5) 
set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" ;;\n        6) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" ;;\n        7) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" \"$args6\" ;;\n        8) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" \"$args6\" \"$args7\" ;;\n        9) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" \"$args6\" \"$args7\" \"$args8\" ;;\n    esac\nfi\n\n# Escape application args\nsave () {\n    for i do printf %s\\\\n \"$i\" | sed \"s/'/'\\\\\\\\''/g;1s/^/'/;\\$s/\\$/' \\\\\\\\/\" ; done\n    echo \" \"\n}\nAPP_ARGS=`save \"$@\"`\n\n# Collect all arguments for the java command, following the shell quoting and substitution rules\neval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS \"\\\"-Dorg.gradle.appname=$APP_BASE_NAME\\\"\" -classpath \"\\\"$CLASSPATH\\\"\" org.gradle.wrapper.GradleWrapperMain \"$APP_ARGS\"\n\nexec \"$JAVACMD\" \"$@\"\n"
  },
  {
    "path": "android/MLCChat/gradlew.bat",
    "content": "@rem\n@rem Copyright 2015 the original author or authors.\n@rem\n@rem Licensed under the Apache License, Version 2.0 (the \"License\");\n@rem you may not use this file except in compliance with the License.\n@rem You may obtain a copy of the License at\n@rem\n@rem      https://www.apache.org/licenses/LICENSE-2.0\n@rem\n@rem Unless required by applicable law or agreed to in writing, software\n@rem distributed under the License is distributed on an \"AS IS\" BASIS,\n@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n@rem See the License for the specific language governing permissions and\n@rem limitations under the License.\n@rem\n\n@if \"%DEBUG%\" == \"\" @echo off\n@rem ##########################################################################\n@rem\n@rem  Gradle startup script for Windows\n@rem\n@rem ##########################################################################\n\n@rem Set local scope for the variables with windows NT shell\nif \"%OS%\"==\"Windows_NT\" setlocal\n\nset DIRNAME=%~dp0\nif \"%DIRNAME%\" == \"\" set DIRNAME=.\nset APP_BASE_NAME=%~n0\nset APP_HOME=%DIRNAME%\n\n@rem Resolve any \".\" and \"..\" in APP_HOME to make it shorter.\nfor %%i in (\"%APP_HOME%\") do set APP_HOME=%%~fi\n\n@rem Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.\nset DEFAULT_JVM_OPTS=\"-Xmx64m\" \"-Xms64m\"\n\n@rem Find java.exe\nif defined JAVA_HOME goto findJavaFromJavaHome\n\nset JAVA_EXE=java.exe\n%JAVA_EXE% -version >NUL 2>&1\nif \"%ERRORLEVEL%\" == \"0\" goto execute\n\necho.\necho ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.\necho.\necho Please set the JAVA_HOME variable in your environment to match the\necho location of your Java installation.\n\ngoto fail\n\n:findJavaFromJavaHome\nset JAVA_HOME=%JAVA_HOME:\"=%\nset JAVA_EXE=%JAVA_HOME%/bin/java.exe\n\nif exist \"%JAVA_EXE%\" goto execute\n\necho.\necho ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%\necho.\necho Please set the JAVA_HOME variable in your environment to match the\necho location of your Java installation.\n\ngoto fail\n\n:execute\n@rem Setup the command line\n\nset CLASSPATH=%APP_HOME%\\gradle\\wrapper\\gradle-wrapper.jar\n\n\n@rem Execute Gradle\n\"%JAVA_EXE%\" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% \"-Dorg.gradle.appname=%APP_BASE_NAME%\" -classpath \"%CLASSPATH%\" org.gradle.wrapper.GradleWrapperMain %*\n\n:end\n@rem End local scope for the variables with windows NT shell\nif \"%ERRORLEVEL%\"==\"0\" goto mainEnd\n\n:fail\nrem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of\nrem the _cmd.exe /c_ return code!\nif  not \"\" == \"%GRADLE_EXIT_CONSOLE%\" exit 1\nexit /b 1\n\n:mainEnd\nif \"%OS%\"==\"Windows_NT\" endlocal\n\n:omega\n"
  },
  {
    "path": "android/MLCChat/mlc-package-config.json",
    "content": "{\n    \"device\": \"android\",\n    \"model_list\": [\n        {\n            \"model\": \"HF://mlc-ai/Phi-3.5-mini-instruct-q4f16_0-MLC\",\n            \"estimated_vram_bytes\": 4250586449,\n            \"model_id\": \"Phi-3.5-mini-instruct-q4f16_0-MLC\",\n            \"overrides\": {\n                \"prefill_chunk_size\": 128\n            }\n        },\n        {\n            \"model\": \"HF://mlc-ai/Qwen3-0.6B-q0f16-MLC\",\n            \"model_id\": \"Qwen3-0.6B-q0f16-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128,\n                \"context_window_size\": 2048\n            }\n        },\n        {\n            \"model\": \"HF://mlc-ai/Qwen3-1.7B-q4f16_1-MLC\",\n            \"model_id\": \"Qwen3-1.7B-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128,\n                \"context_window_size\": 2048\n            }\n        },\n        {\n            \"model\": \"HF://mlc-ai/gemma-2-2b-it-q4f16_1-MLC\",\n            \"model_id\": \"gemma-2-2b-it-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3000000000\n        },\n        {\n            \"model\": \"HF://mlc-ai/Llama-3.2-3B-Instruct-q4f16_0-MLC\",\n            \"estimated_vram_bytes\": 4679979417,\n            \"model_id\": \"Llama-3.2-3B-Instruct-q4f16_0-MLC\"\n        },\n        {\n            \"model\": \"HF://mlc-ai/Mistral-7B-Instruct-v0.3-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 4115131883,\n            \"model_id\": \"Mistral-7B-Instruct-v0.3-q4f16_1-MLC\",\n            \"overrides\": {\n                \"sliding_window_size\": 768,\n                \"prefill_chunk_size\": 256\n            }\n        }\n    ]\n}\n"
  },
  {
    "path": "android/MLCChat/settings.gradle",
    "content": "pluginManagement {\n    repositories {\n        google()\n        mavenCentral()\n        gradlePluginPortal()\n    }\n}\ndependencyResolutionManagement {\n    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)\n    repositories {\n        google()\n        mavenCentral()\n        maven { url \"https://jitpack.io\" }\n    }\n}\nrootProject.name = \"MLCChat\"\ninclude ':app'\ninclude ':mlc4j'\nproject(':mlc4j').projectDir = file('dist/lib/mlc4j')\ninclude ':mlcengineexample'\n"
  },
  {
    "path": "android/MLCEngineExample/README.md",
    "content": "# MLC-LLM Android\n\nCheck out the [documentation page](https://llm.mlc.ai/docs/deploy/android.html) for more information.\n\n- Run `mlc_llm package`.\n- Open this `MLCEngineExample/` folder as a project in Android Studio.\n"
  },
  {
    "path": "android/MLCEngineExample/app/.gitignore",
    "content": "/build\n/src/main/libs\n"
  },
  {
    "path": "android/MLCEngineExample/app/build.gradle",
    "content": "plugins {\n    id 'com.android.application'\n    id 'org.jetbrains.kotlin.android'\n}\n\nandroid {\n    namespace 'ai.mlc.mlcengineexample'\n    compileSdk 34\n\n    defaultConfig {\n        applicationId \"ai.mlc.mlcengineexample\"\n        minSdk 26\n        targetSdk 33\n        versionCode 1\n        versionName \"1.0\"\n\n        testInstrumentationRunner \"androidx.test.runner.AndroidJUnitRunner\"\n        vectorDrawables {\n            useSupportLibrary true\n        }\n    }\n\n    buildTypes {\n        release {\n            minifyEnabled false\n            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'\n        }\n    }\n    compileOptions {\n        sourceCompatibility JavaVersion.VERSION_1_8\n        targetCompatibility JavaVersion.VERSION_1_8\n    }\n    kotlinOptions {\n        jvmTarget = '1.8'\n    }\n    buildFeatures {\n        compose true\n    }\n    composeOptions {\n        kotlinCompilerExtensionVersion '1.4.3'\n    }\n    packagingOptions {\n        resources {\n            excludes += '/META-INF/{AL2.0,LGPL2.1}'\n        }\n    }\n}\n\ndependencies {\n    implementation project(\":mlc4j\")\n    implementation 'androidx.core:core-ktx:1.10.1'\n    implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.6.1'\n    implementation 'androidx.activity:activity-compose:1.7.1'\n    implementation platform('androidx.compose:compose-bom:2022.10.00')\n    implementation 'androidx.lifecycle:lifecycle-viewmodel-compose:2.6.1'\n    implementation 'androidx.compose.ui:ui'\n    implementation 'androidx.compose.ui:ui-graphics'\n    implementation 'androidx.compose.ui:ui-tooling-preview'\n    implementation 'androidx.compose.material3:material3:1.1.0'\n    implementation 'androidx.compose.material:material-icons-extended'\n    implementation 'androidx.appcompat:appcompat:1.6.1'\n    implementation 'androidx.navigation:navigation-compose:2.5.3'\n    implementation 
'com.google.code.gson:gson:2.10.1'\n    implementation fileTree(dir: 'src/main/libs', include: ['*.aar', '*.jar'], exclude: [])\n    testImplementation 'junit:junit:4.13.2'\n    androidTestImplementation 'androidx.test.ext:junit:1.1.5'\n    androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.1'\n    androidTestImplementation platform('androidx.compose:compose-bom:2022.10.00')\n    androidTestImplementation 'androidx.compose.ui:ui-test-junit4'\n    debugImplementation 'androidx.compose.ui:ui-tooling'\n    debugImplementation 'androidx.compose.ui:ui-test-manifest'\n\n}\n"
  },
  {
    "path": "android/MLCEngineExample/app/proguard-rules.pro",
    "content": "# Add project specific ProGuard rules here.\n# You can control the set of applied configuration files using the\n# proguardFiles setting in build.gradle.\n#\n# For more details, see\n#   http://developer.android.com/guide/developing/tools/proguard.html\n\n# If your project uses WebView with JS, uncomment the following\n# and specify the fully qualified class name to the JavaScript interface\n# class:\n#-keepclassmembers class fqcn.of.javascript.interface.for.webview {\n#   public *;\n#}\n\n# Uncomment this to preserve the line number information for\n# debugging stack traces.\n#-keepattributes SourceFile,LineNumberTable\n\n# If you keep the line number information, uncomment this to\n# hide the original source file name.\n#-renamesourcefileattribute SourceFile\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/AndroidManifest.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<manifest xmlns:android=\"http://schemas.android.com/apk/res/android\"\n    xmlns:tools=\"http://schemas.android.com/tools\"\n    package=\"ai.mlc.mlcengineexample\">\n\n    <uses-permission android:name=\"android.permission.INTERNET\" />\n    <uses-permission\n        android:name=\"android.permission.WRITE_EXTERNAL_STORAGE\"\n        android:maxSdkVersion=\"32\"\n        tools:ignore=\"ScopedStorage\" />\n\n    <application\n        android:allowBackup=\"true\"\n        android:dataExtractionRules=\"@xml/data_extraction_rules\"\n        android:fullBackupContent=\"@xml/backup_rules\"\n        android:icon=\"@drawable/mlc_logo_108\"\n        android:label=\"@string/app_name\"\n        android:roundIcon=\"@drawable/mlc_logo_108\"\n        android:supportsRtl=\"true\"\n        android:theme=\"@style/Theme.MLCEngineExample\"\n        tools:targetApi=\"31\">\n        <uses-native-library\n            android:name=\"libOpenCL.so\"\n            android:required=\"false\"/>\n\n        <uses-native-library\n            android:name=\"libOpenCL-pixel.so\"\n            android:required=\"false\" />\n        <activity\n            android:name=\".MainActivity\"\n            android:exported=\"true\"\n            android:label=\"@string/app_name\"\n            android:theme=\"@android:style/Theme.Material.NoActionBar\">\n            <intent-filter>\n                <action android:name=\"android.intent.action.MAIN\" />\n                <category android:name=\"android.intent.category.LAUNCHER\" />\n            </intent-filter>\n        </activity>\n    </application>\n\n</manifest>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/MainActivity.kt",
    "content": "package ai.mlc.mlcengineexample\n\nimport ai.mlc.mlcengineexample.ui.theme.MLCEngineExampleTheme\nimport ai.mlc.mlcllm.MLCEngine\nimport ai.mlc.mlcllm.OpenAIProtocol\nimport ai.mlc.mlcllm.OpenAIProtocol.*\nimport android.annotation.SuppressLint\nimport android.os.Bundle\nimport android.util.Log\nimport androidx.activity.ComponentActivity\nimport androidx.activity.compose.setContent\nimport androidx.compose.foundation.layout.fillMaxSize\nimport androidx.compose.material3.ExperimentalMaterial3Api\nimport androidx.compose.material3.Surface\nimport androidx.compose.material3.Text\nimport androidx.compose.runtime.mutableStateOf\nimport androidx.compose.runtime.remember\nimport androidx.compose.runtime.rememberCoroutineScope\nimport androidx.compose.ui.Modifier\nimport kotlinx.coroutines.GlobalScope\nimport kotlinx.coroutines.channels.ReceiveChannel\nimport kotlinx.coroutines.launch\nimport java.io.File\n\n\nclass MainActivity : ComponentActivity() {\n    @SuppressLint(\"CoroutineCreationDuringComposition\")\n    @ExperimentalMaterial3Api\n    override fun onCreate(savedInstanceState: Bundle?) 
{\n        super.onCreate(savedInstanceState)\n\n        val modelName = \"phi-2-q4f16_1-MLC\"\n        val modelPath = File(application.getExternalFilesDir(\"\"), modelName).toString()\n        Log.i(\"MLC\", \"model path: $modelPath\")\n        // Must be changed to the custom system lib prefix used when compiling the model\n        val modelLib = \"phi_msft_q4f16_1_686d8979c6ebf05d142d9081f1b87162\"\n\n        // Create and load the engine once, outside of composition, so that a\n        // recomposition does not construct a new engine or reload the model.\n        val engine = MLCEngine()\n        engine.unload()\n        engine.reload(modelPath, modelLib)\n        Log.i(\"MLC\", \"engine loaded\")\n\n        setContent {\n            val responseText = remember { mutableStateOf(\"\") }\n            val coroutineScope = rememberCoroutineScope()\n            coroutineScope.launch {\n                val channel = engine.chat.completions.create(\n                    messages = listOf(\n                        ChatCompletionMessage(\n                            role = OpenAIProtocol.ChatCompletionRole.user,\n                            content = \"What is the meaning of life?\"\n                        )\n                    ),\n                    stream_options = OpenAIProtocol.StreamOptions(include_usage = true)\n                )\n\n                for (response in channel) {\n                    val finalusage = response.usage\n                    if (finalusage != null) {\n                        responseText.value += \"\\n\" + (finalusage.extra?.asTextLabel() ?: \"\")\n                    } else {\n                        if (response.choices.size > 0) {\n                            responseText.value += response.choices[0].delta.content?.asText()\n                                .orEmpty()\n                        }\n                    }\n                }\n            }\n\n            Surface(\n                modifier = Modifier\n                    .fillMaxSize()\n            ) {\n                MLCEngineExampleTheme {\n                    Text(text = responseText.value)\n                
}\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Color.kt",
    "content": "package ai.mlc.mlcengineexample.ui.theme\n\nimport androidx.compose.ui.graphics.Color\n\nval Blue10 = Color(0xFF000F5E)\nval Blue20 = Color(0xFF001E92)\nval Blue30 = Color(0xFF002ECC)\nval Blue40 = Color(0xFF1546F6)\nval Blue80 = Color(0xFFB8C3FF)\nval Blue90 = Color(0xFFDDE1FF)\n\nval DarkBlue10 = Color(0xFF00036B)\nval DarkBlue20 = Color(0xFF000BA6)\nval DarkBlue30 = Color(0xFF1026D3)\nval DarkBlue40 = Color(0xFF3648EA)\nval DarkBlue80 = Color(0xFFBBC2FF)\nval DarkBlue90 = Color(0xFFDEE0FF)\n\nval Yellow10 = Color(0xFF261900)\nval Yellow20 = Color(0xFF402D00)\nval Yellow30 = Color(0xFF5C4200)\nval Yellow40 = Color(0xFF7A5900)\nval Yellow80 = Color(0xFFFABD1B)\nval Yellow90 = Color(0xFFFFDE9C)\n\nval Red10 = Color(0xFF410001)\nval Red20 = Color(0xFF680003)\nval Red30 = Color(0xFF930006)\nval Red40 = Color(0xFFBA1B1B)\nval Red80 = Color(0xFFFFB4A9)\nval Red90 = Color(0xFFFFDAD4)\n\nval Grey10 = Color(0xFF191C1D)\nval Grey20 = Color(0xFF2D3132)\nval Grey80 = Color(0xFFC4C7C7)\nval Grey90 = Color(0xFFE0E3E3)\nval Grey95 = Color(0xFFEFF1F1)\nval Grey99 = Color(0xFFFBFDFD)\n\nval BlueGrey30 = Color(0xFF45464F)\nval BlueGrey50 = Color(0xFF767680)\nval BlueGrey60 = Color(0xFF90909A)\nval BlueGrey80 = Color(0xFFC6C5D0)\nval BlueGrey90 = Color(0xFFE2E1EC)\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Theme.kt",
    "content": "package ai.mlc.mlcengineexample.ui.theme\n\nimport android.app.Activity\nimport android.os.Build\nimport androidx.compose.foundation.isSystemInDarkTheme\nimport androidx.compose.material3.MaterialTheme\nimport androidx.compose.material3.darkColorScheme\nimport androidx.compose.material3.dynamicDarkColorScheme\nimport androidx.compose.material3.dynamicLightColorScheme\nimport androidx.compose.material3.lightColorScheme\nimport androidx.compose.runtime.Composable\nimport androidx.compose.runtime.SideEffect\nimport androidx.compose.ui.graphics.Color\nimport androidx.compose.ui.graphics.toArgb\nimport androidx.compose.ui.platform.LocalContext\nimport androidx.compose.ui.platform.LocalView\nimport androidx.core.view.WindowCompat\n\nprivate val DarkColorScheme = darkColorScheme(\n    primary = Blue80,\n    onPrimary = Blue20,\n    primaryContainer = Blue30,\n    onPrimaryContainer = Blue90,\n    inversePrimary = Blue40,\n    secondary = DarkBlue80,\n    onSecondary = DarkBlue20,\n    secondaryContainer = DarkBlue30,\n    onSecondaryContainer = DarkBlue90,\n    tertiary = Yellow80,\n    onTertiary = Yellow20,\n    tertiaryContainer = Yellow30,\n    onTertiaryContainer = Yellow90,\n    error = Red80,\n    onError = Red20,\n    errorContainer = Red30,\n    onErrorContainer = Red90,\n    background = Grey10,\n    onBackground = Grey90,\n    surface = Grey10,\n    onSurface = Grey80,\n    inverseSurface = Grey90,\n    inverseOnSurface = Grey20,\n    surfaceVariant = BlueGrey30,\n    onSurfaceVariant = BlueGrey80,\n    outline = BlueGrey60\n)\n\nprivate val LightColorScheme = lightColorScheme(\n    primary = Blue40,\n    onPrimary = Color.White,\n    primaryContainer = Blue90,\n    onPrimaryContainer = Blue10,\n    inversePrimary = Blue80,\n    secondary = DarkBlue40,\n    onSecondary = Color.White,\n    secondaryContainer = DarkBlue90,\n    onSecondaryContainer = DarkBlue10,\n    tertiary = Yellow40,\n    onTertiary = Color.White,\n    tertiaryContainer = 
Yellow90,\n    onTertiaryContainer = Yellow10,\n    error = Red40,\n    onError = Color.White,\n    errorContainer = Red90,\n    onErrorContainer = Red10,\n    background = Grey99,\n    onBackground = Grey10,\n    surface = Grey99,\n    onSurface = Grey10,\n    inverseSurface = Grey20,\n    inverseOnSurface = Grey95,\n    surfaceVariant = BlueGrey90,\n    onSurfaceVariant = BlueGrey30,\n    outline = BlueGrey50\n)\n\n@Composable\nfun MLCEngineExampleTheme(\n    darkTheme: Boolean = isSystemInDarkTheme(),\n    // Dynamic color is available on Android 12+\n    dynamicColor: Boolean = true,\n    content: @Composable () -> Unit\n) {\n    val colorScheme = when {\n        dynamicColor && Build.VERSION.SDK_INT >= Build.VERSION_CODES.S -> {\n            val context = LocalContext.current\n            if (darkTheme) dynamicDarkColorScheme(context) else dynamicLightColorScheme(context)\n        }\n\n        darkTheme -> DarkColorScheme\n        else -> LightColorScheme\n    }\n    val view = LocalView.current\n    if (!view.isInEditMode) {\n        SideEffect {\n            val window = (view.context as Activity).window\n            window.statusBarColor = colorScheme.primary.toArgb()\n            WindowCompat.getInsetsController(window, view).isAppearanceLightStatusBars = darkTheme\n        }\n    }\n\n    MaterialTheme(\n        colorScheme = colorScheme,\n        typography = Typography,\n        content = content\n    )\n}\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/java/ai/mlc/mlcengineexample/ui/theme/Type.kt",
    "content": "package ai.mlc.mlcengineexample.ui.theme\n\nimport androidx.compose.material3.Typography\nimport androidx.compose.ui.text.TextStyle\nimport androidx.compose.ui.text.font.FontFamily\nimport androidx.compose.ui.text.font.FontWeight\nimport androidx.compose.ui.unit.sp\n\n// Set of Material typography styles to start with\nval Typography = Typography(\n    bodyLarge = TextStyle(\n        fontFamily = FontFamily.Default,\n        fontWeight = FontWeight.Normal,\n        fontSize = 16.sp,\n        lineHeight = 24.sp,\n        letterSpacing = 0.5.sp\n    )\n    /* Other default text styles to override\n    titleLarge = TextStyle(\n        fontFamily = FontFamily.Default,\n        fontWeight = FontWeight.Normal,\n        fontSize = 22.sp,\n        lineHeight = 28.sp,\n        letterSpacing = 0.sp\n    ),\n    labelSmall = TextStyle(\n        fontFamily = FontFamily.Default,\n        fontWeight = FontWeight.Medium,\n        fontSize = 11.sp,\n        lineHeight = 16.sp,\n        letterSpacing = 0.5.sp\n    )\n    */\n)\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/drawable/ic_android_black_24dp.xml",
    "content": "<vector android:height=\"24dp\" android:tint=\"#000000\"\n    android:viewportHeight=\"24\" android:viewportWidth=\"24\"\n    android:width=\"24dp\" xmlns:android=\"http://schemas.android.com/apk/res/android\">\n    <path android:fillColor=\"#FF000000\" android:pathData=\"M17.6,11.48 L19.44,8.3a0.63,0.63 0,0 0,-1.09 -0.63l-1.88,3.24a11.43,11.43 0,0 0,-8.94 0L5.65,7.67a0.63,0.63 0,0 0,-1.09 0.63L6.4,11.48A10.81,10.81 0,0 0,1 20L23,20A10.81,10.81 0,0 0,17.6 11.48ZM7,17.25A1.25,1.25 0,1 1,8.25 16,1.25 1.25,0 0,1 7,17.25ZM17,17.25A1.25,1.25 0,1 1,18.25 16,1.25 1.25,0 0,1 17,17.25Z\"/>\n</vector>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/drawable/mlc_logo_108.xml",
    "content": "<vector xmlns:android=\"http://schemas.android.com/apk/res/android\"\n    android:width=\"108dp\"\n    android:height=\"108dp\"\n    android:viewportWidth=\"108\"\n    android:viewportHeight=\"108\">\n  <path\n      android:pathData=\"M100.93,47.91L58.41,47.91C57,47.91 55.82,49.05 55.82,50.5L55.82,69.17C57.54,68.98 59.14,69.2 60.55,70.04L60.55,52.72L98.79,52.72L98.79,103.09L60.55,103.09L60.55,87.29C59.48,88.09 58.26,88.75 57.08,89.28C56.7,89.47 56.27,89.66 55.82,89.81L55.82,105.23C55.82,106.64 56.96,107.82 58.41,107.82L100.93,107.82C102.34,107.82 103.52,106.68 103.52,105.23L103.52,50.5C103.52,49.09 102.34,47.91 100.93,47.91ZM55.93,86.72C52.88,88.13 47.57,89.39 44.29,90.12C40.63,90.92 30.02,93.36 29.1,87.6C28.34,82.87 40.82,77.6 44.02,76.23C46.96,74.97 49.94,73.79 52.92,72.64C56.12,71.42 59.56,70.88 61.16,74.93C61.85,76.68 62.04,78.14 62.07,80L62.07,80.2L62.04,80.39C61.66,83.55 58.57,85.5 55.93,86.72ZM66.58,35.01C68.03,34.9 69.29,35.96 69.4,37.38L69.82,42.3C69.94,43.75 68.87,45.01 67.46,45.13C66.01,45.24 64.75,44.17 64.63,42.76L64.21,37.84C64.06,36.42 65.13,35.16 66.58,35.01ZM85.55,45.96C85.55,43.03 85.39,40.2 85.13,37.53L85.13,37.57C85.01,36.65 84.9,35.74 84.78,34.82L84.78,34.75L84.75,34.59C84.25,31.23 83.6,27.91 82.8,24.55C82,21.2 79.1,19.1 75.7,19.02C69.52,18.87 63.3,18.79 57.11,19.02C56.05,17.07 52.88,15.86 49.18,16.12L44.9,6.5C45.7,5.78 46.13,4.67 46.05,3.53C45.86,1.5 44.1,0.02 42.08,0.21C40.05,0.4 38.57,2.15 38.76,4.18C38.91,6.09 40.52,7.5 42.38,7.5L46.43,16.58C43.76,17.3 41.77,18.79 41.24,20.47C35.4,21.31 29.64,22.46 23.88,23.6C23.19,23.75 22.54,23.95 21.93,24.25C15.44,25.81 8.76,29.36 8.76,29.36C8.69,30.2 8.61,31.04 8.54,31.92C8.84,31.84 9.18,31.8 9.53,31.77C14.6,31.31 19.22,36.54 19.79,43.41C20.4,50.28 16.78,56.23 11.7,56.65C11.32,56.69 10.94,56.69 10.55,56.65C10.79,57.57 11.02,58.44 11.24,59.32C15.48,61.61 21.2,63.18 24.75,63.94C25.59,64.24 26.47,64.43 27.43,64.43C36.13,64.63 44.82,64.74 53.53,63.98L53.87,63.94L53.87,57.57C47.99,58.06 
41.96,58.1 32.92,57.91C31.2,57.87 29.94,56.84 29.52,55.35C27.5,48.02 27,40.43 27.46,32.23C27.54,30.66 28.64,29.44 30.36,29.1C32,28.75 33.57,28.45 35.02,28.18C35.71,28.07 37.5,27.72 38.91,27.46L38.95,27.46C40.02,27.27 41.05,27.04 42.12,26.88C45.44,27.46 47.73,32.3 52.69,31.5C57.69,31.43 59.1,26.23 62.3,25.09C64.63,25.09 66.92,25.13 69.25,25.24C70.36,25.24 71.43,25.28 73.14,25.32C74.86,25.36 76.16,26.39 76.54,27.88C76.92,29.52 77.27,31.12 77.57,32.72C78.18,38.22 78.45,42.64 78.41,46.04L85.55,46.04ZM9.79,38.06C11.78,37.88 13.65,40.58 13.95,44.09C14.26,47.61 12.92,50.58 10.94,50.73C9.98,50.81 9.03,50.24 8.3,49.17C8.72,49.51 9.18,49.66 9.64,49.63C11.2,49.48 12.27,47.07 12.01,44.25C11.74,41.42 10.29,39.25 8.72,39.36C8.27,39.4 7.85,39.63 7.5,40.05C8,38.9 8.8,38.18 9.79,38.06ZM52.65,21.88C54.29,21.88 55.59,23.22 55.59,24.82C55.59,26.46 54.25,27.76 52.65,27.76C51.01,27.76 49.71,26.43 49.71,24.82C49.67,23.22 51.01,21.88 52.65,21.88ZM42.31,37.19C43.76,37.07 45.02,38.14 45.13,39.55L45.55,44.48C45.66,45.93 44.6,47.18 43.18,47.3C41.73,47.41 40.48,46.34 40.36,44.93L39.94,40.01C39.79,38.56 40.86,37.3 42.31,37.19ZM9.75,34.06C13.5,33.71 16.97,37.95 17.43,43.52C17.92,49.09 15.29,53.86 11.51,54.17C7.77,54.51 4.3,50.28 3.84,44.7C3.34,39.17 5.98,34.4 9.75,34.06ZM53.91,100.73C49.98,99.7 46.54,97.1 45.02,92.79C47.77,92.18 51.01,91.45 53.91,90.46ZM42.84,73.79L42.19,66.46L53.91,65.85L53.91,69.47C53.26,69.63 52.61,69.86 51.96,70.08C48.95,71.23 45.93,72.45 42.96,73.71ZM29.64,73.59C33.19,71 37.61,71.04 39.52,73.67C39.83,74.09 40.02,74.51 40.17,74.97C35.82,76.91 29.83,80 27.43,83.86C27.12,83.63 26.85,83.32 26.62,83.02C24.71,80.39 26.05,76.15 29.64,73.59ZM78.68,84.28C79.36,84.13 80.09,84.13 80.77,84.28L81.58,82.91L81.92,83.02C82.61,83.29 83.25,83.63 83.83,84.13L84.09,84.36L83.33,85.77C83.56,86.04 83.79,86.3 83.94,86.61C84.13,86.91 84.25,87.22 84.36,87.56L85.96,87.56L86.04,87.94C86.16,88.67 86.16,89.43 86.04,90.16L85.96,90.5L84.36,90.54C84.13,91.23 83.79,91.84 
83.33,92.33L84.13,93.71L83.87,93.93C83.56,94.16 83.25,94.39 82.95,94.58C82.64,94.77 82.3,94.93 81.96,95.04L81.61,95.16L80.77,93.78C80.09,93.93 79.36,93.93 78.68,93.78L77.88,95.16L77.53,95.04C76.84,94.77 76.2,94.43 75.63,93.93L75.36,93.71L76.12,92.29C75.89,92.03 75.66,91.76 75.51,91.45C75.32,91.15 75.2,90.84 75.09,90.5L73.48,90.5L73.41,90.12C73.3,89.39 73.3,88.63 73.41,87.91L73.48,87.56L75.09,87.52C75.32,86.84 75.66,86.23 76.12,85.73L75.32,84.36L75.59,84.13C75.89,83.9 76.2,83.67 76.5,83.48C76.8,83.29 77.15,83.13 77.49,83.02L77.84,82.91ZM64.18,57.76L94.36,57.76L94.36,61L64.18,61ZM64.18,64.97L76.39,64.97L76.39,68.21L64.18,68.21ZM64.18,72.34L74.02,72.34L74.02,75.58L64.25,75.58L64.18,75.31ZM90.09,67.49C91,67.79 91.84,68.29 92.57,68.9L94.48,67.79L94.82,68.18C95.47,68.94 96,69.86 96.34,70.81L96.54,71.27L94.67,72.41C94.78,72.87 94.82,73.36 94.82,73.86C94.82,74.36 94.78,74.82 94.67,75.27L96.57,76.38L96.38,76.84C96.04,77.79 95.51,78.67 94.86,79.47L94.55,79.85L92.64,78.79C91.92,79.43 91.08,79.93 90.16,80.23L90.16,82.41L89.67,82.48C89.17,82.56 88.64,82.64 88.14,82.64C87.64,82.64 87.15,82.6 86.65,82.52L86.16,82.45L86.12,80.23C85.2,79.93 84.36,79.43 83.64,78.82L81.73,79.93L81.39,79.55C80.74,78.79 80.2,77.87 79.86,76.91L79.67,76.46L81.54,75.31C81.43,74.85 81.39,74.36 81.39,73.86C81.39,73.36 81.43,72.91 81.54,72.45L79.63,71.34L79.82,70.85C80.16,69.89 80.7,69.02 81.35,68.21L81.65,67.83L83.56,68.9C84.29,68.25 85.13,67.75 86.04,67.45L86.04,65.31L86.54,65.23C87.04,65.16 87.57,65.08 88.06,65.08C88.56,65.08 89.05,65.12 89.55,65.2L90.05,65.27ZM88.06,70.54C86.2,70.54 84.71,72.03 84.71,73.9C84.71,75.77 86.2,77.26 88.06,77.26C89.93,77.26 91.42,75.77 91.42,73.9C91.42,72.03 89.89,70.54 88.06,70.54ZM78.48,86.95C77.3,87.64 76.92,89.13 77.61,90.31C78.29,91.49 79.78,91.88 80.96,91.19C82.15,90.5 82.53,89.01 81.84,87.83C81.16,86.68 79.67,86.26 78.48,86.95ZM78.48,86.95\"\n      android:fillColor=\"#062578\"\n      android:fillType=\"evenOdd\"\n      android:strokeColor=\"#00000000\"/>\n</vector>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/values/colors.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<resources>\n    <color name=\"purple_200\">#FFBB86FC</color>\n    <color name=\"purple_500\">#FF6200EE</color>\n    <color name=\"purple_700\">#FF3700B3</color>\n    <color name=\"teal_200\">#FF03DAC5</color>\n    <color name=\"teal_700\">#FF018786</color>\n    <color name=\"black\">#FF000000</color>\n    <color name=\"white\">#FFFFFFFF</color>\n</resources>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/values/strings.xml",
    "content": "<resources>\n    <string name=\"app_name\">MLCEngineExample</string>\n</resources>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/values/themes.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<resources>\n\n    <style name=\"Theme.MLCEngineExample\" parent=\"android:Theme.Material.Light\" />\n\n</resources>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/xml/backup_rules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?><!--\n   Sample backup rules file; uncomment and customize as necessary.\n   See https://developer.android.com/guide/topics/data/autobackup\n   for details.\n   Note: This file is ignored for devices older than API 31\n   See https://developer.android.com/about/versions/12/backup-restore\n-->\n<full-backup-content>\n    <!--\n   <include domain=\"sharedpref\" path=\".\"/>\n   <exclude domain=\"sharedpref\" path=\"device.xml\"/>\n-->\n</full-backup-content>\n"
  },
  {
    "path": "android/MLCEngineExample/app/src/main/res/xml/data_extraction_rules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?><!--\n   Sample data extraction rules file; uncomment and customize as necessary.\n   See https://developer.android.com/about/versions/12/backup-restore#xml-changes\n   for details.\n-->\n<data-extraction-rules>\n    <cloud-backup>\n        <!-- TODO: Use <include> and <exclude> to control what is backed up.\n        <include .../>\n        <exclude .../>\n        -->\n    </cloud-backup>\n    <!--\n    <device-transfer>\n        <include .../>\n        <exclude .../>\n    </device-transfer>\n    -->\n</data-extraction-rules>\n"
  },
  {
    "path": "android/MLCEngineExample/build.gradle",
    "content": "plugins {\n    id 'com.android.application' version '8.2.0' apply false\n    id 'com.android.library' version '8.2.0' apply false\n    id 'org.jetbrains.kotlin.android' version '1.8.10' apply false\n}\n"
  },
  {
    "path": "android/MLCEngineExample/bundle_weight.py",
    "content": "import argparse\nimport os\nimport subprocess\nfrom pathlib import Path\n\nfrom mlc_llm.support import logging\n\nlogging.enable_logging()\nlogger = logging.getLogger(__name__)\n\n\ndef main(apk_path: Path, package_output_path: Path):\n    \"\"\"Install the apk and push the bundled model weights to the Android device with adb\"\"\"\n    # - Install the apk on device.\n    logger.info('Install apk \"%s\" to device', str(apk_path.absolute()))\n    subprocess.run([\"adb\", \"install\", str(apk_path)], check=True, env=os.environ)\n    # - Create the weight directory for the app.\n    device_weight_dir = \"/storage/emulated/0/Android/data/ai.mlc.mlcengineexample/files/\"\n    logger.info('Creating directory \"%s\" on device', device_weight_dir)\n    subprocess.run(\n        [\"adb\", \"shell\", \"mkdir\", \"-p\", device_weight_dir],\n        check=True,\n        env=os.environ,\n    )\n    for model_weight_dir in (package_output_path / \"bundle\").iterdir():\n        if model_weight_dir.is_dir():\n            src_path = str(model_weight_dir.absolute())\n            dst_path = \"/data/local/tmp/\" + model_weight_dir.name\n            logger.info('Pushing local weights \"%s\" to device location \"%s\"', src_path, dst_path)\n            subprocess.run([\"adb\", \"push\", src_path, dst_path], check=True, env=os.environ)\n\n            src_path = dst_path\n            dst_path = device_weight_dir\n            logger.info('Move weights from \"%s\" to \"%s\"', src_path, dst_path)\n            subprocess.run([\"adb\", \"shell\", \"mv\", src_path, dst_path], check=True, env=os.environ)\n    logger.info(\"All finished.\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"MLC LLM Android Weight Bundle\")\n\n    def _parse_apk_path(path: str) -> Path:\n        path = Path(path)\n        if not path.exists():\n            raise ValueError(\n                f\"Path {str(path)} is expected to be an apk file, but the file does not 
exist.\"\n            )\n        if not path.is_file():\n            raise ValueError(f\"Path {str(path)} is expected to be an apk file.\")\n        return path\n\n    parser.add_argument(\n        \"--apk-path\",\n        type=_parse_apk_path,\n        default=\"app/release/app-release.apk\",\n        help=\"The path to generated MLCEngineExample apk file.\",\n    )\n    parser.add_argument(\n        \"--package-output-path\",\n        type=Path,\n        default=\"dist\",\n        help='The path to the output directory of \"mlc_llm package\".',\n    )\n    args = parser.parse_args()\n    main(args.apk_path, args.package_output_path)\n"
  },
  {
    "path": "android/MLCEngineExample/gradle/wrapper/gradle-wrapper.properties",
    "content": "#Thu Jan 25 10:19:50 EST 2024\ndistributionBase=GRADLE_USER_HOME\ndistributionPath=wrapper/dists\ndistributionUrl=https\\://services.gradle.org/distributions/gradle-8.5-bin.zip\nzipStoreBase=GRADLE_USER_HOME\nzipStorePath=wrapper/dists\n"
  },
  {
    "path": "android/MLCEngineExample/gradle.properties",
    "content": "# Project-wide Gradle settings.\n# IDE (e.g. Android Studio) users:\n# Gradle settings configured through the IDE *will override*\n# any settings specified in this file.\n# For more details on how to configure your build environment visit\n# http://www.gradle.org/docs/current/userguide/build_environment.html\n# Specifies the JVM arguments used for the daemon process.\n# The setting is particularly useful for tweaking memory settings.\norg.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8\n# When configured, Gradle will run in incubating parallel mode.\n# This option should only be used with decoupled projects. More details, visit\n# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects\n# org.gradle.parallel=true\n# AndroidX package structure to make it clearer which packages are bundled with the\n# Android operating system, and which are packaged with your app's APK\n# https://developer.android.com/topic/libraries/support-library/androidx-rn\nandroid.useAndroidX=true\n# Kotlin code style for this project: \"official\" or \"obsolete\":\nkotlin.code.style=official\n# Enables namespacing of each library's R class so that its R class includes only the\n# resources declared in the library itself and none from the library's dependencies,\n# thereby reducing the size of the R class for that library\nandroid.nonTransitiveRClass=true\n"
  },
  {
    "path": "android/MLCEngineExample/gradlew",
    "content": "#!/usr/bin/env sh\n\n#\n# Copyright 2015 the original author or authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#      https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n\n##############################################################################\n##\n##  Gradle start up script for UN*X\n##\n##############################################################################\n\n# Attempt to set APP_HOME\n# Resolve links: $0 may be a link\nPRG=\"$0\"\n# Need this for relative symlinks.\nwhile [ -h \"$PRG\" ] ; do\n    ls=`ls -ld \"$PRG\"`\n    link=`expr \"$ls\" : '.*-> \\(.*\\)$'`\n    if expr \"$link\" : '/.*' > /dev/null; then\n        PRG=\"$link\"\n    else\n        PRG=`dirname \"$PRG\"`\"/$link\"\n    fi\ndone\nSAVED=\"`pwd`\"\ncd \"`dirname \\\"$PRG\\\"`/\" >/dev/null\nAPP_HOME=\"`pwd -P`\"\ncd \"$SAVED\" >/dev/null\n\nAPP_NAME=\"Gradle\"\nAPP_BASE_NAME=`basename \"$0\"`\n\n# Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.\nDEFAULT_JVM_OPTS='\"-Xmx64m\" \"-Xms64m\"'\n\n# Use the maximum available, or set MAX_FD != -1 to use that value.\nMAX_FD=\"maximum\"\n\nwarn () {\n    echo \"$*\"\n}\n\ndie () {\n    echo\n    echo \"$*\"\n    echo\n    exit 1\n}\n\n# OS specific support (must be 'true' or 'false').\ncygwin=false\nmsys=false\ndarwin=false\nnonstop=false\ncase \"`uname`\" in\n  CYGWIN* )\n    cygwin=true\n    ;;\n  Darwin* )\n    darwin=true\n    ;;\n  MINGW* )\n    msys=true\n    ;;\n  NONSTOP* )\n    nonstop=true\n    ;;\nesac\n\nCLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar\n\n\n# Determine the Java command to use to start the JVM.\nif [ -n \"$JAVA_HOME\" ] ; then\n    if [ -x \"$JAVA_HOME/jre/sh/java\" ] ; then\n        # IBM's JDK on AIX uses strange locations for the executables\n        JAVACMD=\"$JAVA_HOME/jre/sh/java\"\n    else\n        JAVACMD=\"$JAVA_HOME/bin/java\"\n    fi\n    if [ ! -x \"$JAVACMD\" ] ; then\n        die \"ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME\n\nPlease set the JAVA_HOME variable in your environment to match the\nlocation of your Java installation.\"\n    fi\nelse\n    JAVACMD=\"java\"\n    which java >/dev/null 2>&1 || die \"ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.\n\nPlease set the JAVA_HOME variable in your environment to match the\nlocation of your Java installation.\"\nfi\n\n# Increase the maximum file descriptors if we can.\nif [ \"$cygwin\" = \"false\" -a \"$darwin\" = \"false\" -a \"$nonstop\" = \"false\" ] ; then\n    MAX_FD_LIMIT=`ulimit -H -n`\n    if [ $? -eq 0 ] ; then\n        if [ \"$MAX_FD\" = \"maximum\" -o \"$MAX_FD\" = \"max\" ] ; then\n            MAX_FD=\"$MAX_FD_LIMIT\"\n        fi\n        ulimit -n $MAX_FD\n        if [ $? 
-ne 0 ] ; then\n            warn \"Could not set maximum file descriptor limit: $MAX_FD\"\n        fi\n    else\n        warn \"Could not query maximum file descriptor limit: $MAX_FD_LIMIT\"\n    fi\nfi\n\n# For Darwin, add options to specify how the application appears in the dock\nif $darwin; then\n    GRADLE_OPTS=\"$GRADLE_OPTS \\\"-Xdock:name=$APP_NAME\\\" \\\"-Xdock:icon=$APP_HOME/media/gradle.icns\\\"\"\nfi\n\n# For Cygwin or MSYS, switch paths to Windows format before running java\nif [ \"$cygwin\" = \"true\" -o \"$msys\" = \"true\" ] ; then\n    APP_HOME=`cygpath --path --mixed \"$APP_HOME\"`\n    CLASSPATH=`cygpath --path --mixed \"$CLASSPATH\"`\n\n    JAVACMD=`cygpath --unix \"$JAVACMD\"`\n\n    # We build the pattern for arguments to be converted via cygpath\n    ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`\n    SEP=\"\"\n    for dir in $ROOTDIRSRAW ; do\n        ROOTDIRS=\"$ROOTDIRS$SEP$dir\"\n        SEP=\"|\"\n    done\n    OURCYGPATTERN=\"(^($ROOTDIRS))\"\n    # Add a user-defined pattern to the cygpath arguments\n    if [ \"$GRADLE_CYGPATTERN\" != \"\" ] ; then\n        OURCYGPATTERN=\"$OURCYGPATTERN|($GRADLE_CYGPATTERN)\"\n    fi\n    # Now convert the arguments - kludge to limit ourselves to /bin/sh\n    i=0\n    for arg in \"$@\" ; do\n        CHECK=`echo \"$arg\"|egrep -c \"$OURCYGPATTERN\" -`\n        CHECK2=`echo \"$arg\"|egrep -c \"^-\"`                                 ### Determine if an option\n\n        if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then                    ### Added a condition\n            eval `echo args$i`=`cygpath --path --ignore --mixed \"$arg\"`\n        else\n            eval `echo args$i`=\"\\\"$arg\\\"\"\n        fi\n        i=`expr $i + 1`\n    done\n    case $i in\n        0) set -- ;;\n        1) set -- \"$args0\" ;;\n        2) set -- \"$args0\" \"$args1\" ;;\n        3) set -- \"$args0\" \"$args1\" \"$args2\" ;;\n        4) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" ;;\n        5) 
set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" ;;\n        6) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" ;;\n        7) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" \"$args6\" ;;\n        8) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" \"$args6\" \"$args7\" ;;\n        9) set -- \"$args0\" \"$args1\" \"$args2\" \"$args3\" \"$args4\" \"$args5\" \"$args6\" \"$args7\" \"$args8\" ;;\n    esac\nfi\n\n# Escape application args\nsave () {\n    for i do printf %s\\\\n \"$i\" | sed \"s/'/'\\\\\\\\''/g;1s/^/'/;\\$s/\\$/' \\\\\\\\/\" ; done\n    echo \" \"\n}\nAPP_ARGS=`save \"$@\"`\n\n# Collect all arguments for the java command, following the shell quoting and substitution rules\neval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS \"\\\"-Dorg.gradle.appname=$APP_BASE_NAME\\\"\" -classpath \"\\\"$CLASSPATH\\\"\" org.gradle.wrapper.GradleWrapperMain \"$APP_ARGS\"\n\nexec \"$JAVACMD\" \"$@\"\n"
  },
  {
    "path": "android/MLCEngineExample/gradlew.bat",
    "content": "@rem\n@rem Copyright 2015 the original author or authors.\n@rem\n@rem Licensed under the Apache License, Version 2.0 (the \"License\");\n@rem you may not use this file except in compliance with the License.\n@rem You may obtain a copy of the License at\n@rem\n@rem      https://www.apache.org/licenses/LICENSE-2.0\n@rem\n@rem Unless required by applicable law or agreed to in writing, software\n@rem distributed under the License is distributed on an \"AS IS\" BASIS,\n@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n@rem See the License for the specific language governing permissions and\n@rem limitations under the License.\n@rem\n\n@if \"%DEBUG%\" == \"\" @echo off\n@rem ##########################################################################\n@rem\n@rem  Gradle startup script for Windows\n@rem\n@rem ##########################################################################\n\n@rem Set local scope for the variables with windows NT shell\nif \"%OS%\"==\"Windows_NT\" setlocal\n\nset DIRNAME=%~dp0\nif \"%DIRNAME%\" == \"\" set DIRNAME=.\nset APP_BASE_NAME=%~n0\nset APP_HOME=%DIRNAME%\n\n@rem Resolve any \".\" and \"..\" in APP_HOME to make it shorter.\nfor %%i in (\"%APP_HOME%\") do set APP_HOME=%%~fi\n\n@rem Add default JVM options here. 
You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.\nset DEFAULT_JVM_OPTS=\"-Xmx64m\" \"-Xms64m\"\n\n@rem Find java.exe\nif defined JAVA_HOME goto findJavaFromJavaHome\n\nset JAVA_EXE=java.exe\n%JAVA_EXE% -version >NUL 2>&1\nif \"%ERRORLEVEL%\" == \"0\" goto execute\n\necho.\necho ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.\necho.\necho Please set the JAVA_HOME variable in your environment to match the\necho location of your Java installation.\n\ngoto fail\n\n:findJavaFromJavaHome\nset JAVA_HOME=%JAVA_HOME:\"=%\nset JAVA_EXE=%JAVA_HOME%/bin/java.exe\n\nif exist \"%JAVA_EXE%\" goto execute\n\necho.\necho ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%\necho.\necho Please set the JAVA_HOME variable in your environment to match the\necho location of your Java installation.\n\ngoto fail\n\n:execute\n@rem Setup the command line\n\nset CLASSPATH=%APP_HOME%\\gradle\\wrapper\\gradle-wrapper.jar\n\n\n@rem Execute Gradle\n\"%JAVA_EXE%\" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% \"-Dorg.gradle.appname=%APP_BASE_NAME%\" -classpath \"%CLASSPATH%\" org.gradle.wrapper.GradleWrapperMain %*\n\n:end\n@rem End local scope for the variables with windows NT shell\nif \"%ERRORLEVEL%\"==\"0\" goto mainEnd\n\n:fail\nrem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of\nrem the _cmd.exe /c_ return code!\nif  not \"\" == \"%GRADLE_EXIT_CONSOLE%\" exit 1\nexit /b 1\n\n:mainEnd\nif \"%OS%\"==\"Windows_NT\" endlocal\n\n:omega\n"
  },
  {
    "path": "android/MLCEngineExample/mlc-package-config.json",
    "content": "{\n    \"device\": \"android\",\n    \"model_list\": [\n        {\n            \"model\": \"HF://mlc-ai/phi-2-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 2036816936,\n            \"model_id\": \"phi-2-q4f16_1-MLC\",\n            \"overrides\": {\n                \"prefill_chunk_size\": 1024\n            }\n        }\n    ]\n}\n"
  },
  {
    "path": "android/MLCEngineExample/settings.gradle",
    "content": "pluginManagement {\n    repositories {\n        google()\n        mavenCentral()\n        gradlePluginPortal()\n    }\n}\ndependencyResolutionManagement {\n    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)\n    repositories {\n        google()\n        mavenCentral()\n    }\n}\nrootProject.name = \"MLCEngineExample\"\ninclude ':app'\ninclude ':mlc4j'\nproject(':mlc4j').projectDir = file('dist/lib/mlc4j')\n"
  },
  {
    "path": "android/README.md",
    "content": "# MLC-LLM Android\n\n[Documentation page](https://llm.mlc.ai/docs/deploy/android.html)\n"
  },
  {
    "path": "android/mlc4j/.gitignore",
    "content": "/build\n"
  },
  {
    "path": "android/mlc4j/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.18)\n\nproject(mlc-chat C CXX)\n\nset(ANDROID_DIR ${CMAKE_CURRENT_LIST_DIR})\nset(ANDROID_BIN_DIR ${CMAKE_CURRENT_BINARY_DIR})\n\nset(MLC_LLM_DIR ${ANDROID_DIR}/../..)\nset(MLC_LLM_BINARY_DIR mlc_llm)\nset(MLC_LLM_COMPILE_DEFS TVM_LOG_CUSTOMIZE=1)\nadd_subdirectory(${MLC_LLM_DIR} ${MLC_LLM_BINARY_DIR} EXCLUDE_FROM_ALL)\n\nif(NOT DEFINED TVM_SOURCE_DIR)\n  set(TVM_SOURCE_DIR ${MLC_LLM_DIR}/3rdparty/tvm)\nendif(NOT DEFINED TVM_SOURCE_DIR)\nmessage(STATUS \"TVM_SOURCE_DIR: ${TVM_SOURCE_DIR}\")\n\nfind_package(Java REQUIRED)\ninclude(UseJava)\n\nfind_package(JNI)\nif(JNI_FOUND)\n  message(STATUS \"JNI_INCLUDE_DIRS=${JNI_INCLUDE_DIRS}\")\nelse()\n  message(STATUS \"Try to find jni directly from android env\")\n  # try to find JNI_LIBRARY\n  find_path(JNI_INCLUDE_DIRS NAMES \"jni.h\")\n  message(STATUS \"JNI_INCLUDE_DIRS=${JNI_INCLUDE_DIRS}\")\nendif()\n\nfile(GLOB_RECURSE javasources\n     ${TVM_SOURCE_DIR}/jvm/core/src/main/java/org/apache/tvm/*.java\n     ${ANDROID_DIR}/src/java/*.java)\nset(JNI_HEADER ${CMAKE_BINARY_DIR}/jni_header)\nadd_jar(tvm4j_core ${javasources} GENERATE_NATIVE_HEADERS tvm4jheaders\n        DESTINATION ${JNI_HEADER})\n\nadd_custom_command(\n  TARGET tvm4j_core\n  POST_BUILD\n  COMMAND ${CMAKE_COMMAND} -E copy ${JNI_HEADER}/org_apache_tvm_LibInfo.h\n          ${JNI_HEADER}/org_apache_tvm_native_c_api.h)\n\nadd_library(model_android STATIC IMPORTED)\nset_target_properties(\n  model_android PROPERTIES IMPORTED_LOCATION\n                           ${ANDROID_BIN_DIR}/lib/libmodel_android.a)\n\nadd_library(\n  tvm4j_runtime_packed SHARED\n  ${TVM_SOURCE_DIR}/jvm/native/src/main/native/org_apache_tvm_native_c_api.cc)\nset(MLC_LLM_COMPILE_DEFS ${MLC_LLM_COMPILE_DEFS}\n                         TVM_SOURCE_DIR=${TVM_SOURCE_DIR})\n\ntarget_include_directories(\n  tvm4j_runtime_packed\n  PUBLIC ${JNI_INCLUDE_DIRS}\n         ${JNI_HEADER}\n         ${ANDROID_DIR}/src/cpp\n         
${TVM_SOURCE_DIR}/3rdparty/tvm-ffi/3rdparty/dlpack/include\n         ${TVM_SOURCE_DIR}/3rdparty/OpenCL-Headers\n         ${TVM_SOURCE_DIR}/include\n         ${TVM_SOURCE_DIR}/src\n         ${TVM_SOURCE_DIR}/3rdparty/tvm-ffi/include\n         ${TVM_SOURCE_DIR}/3rdparty/tvm-ffi/src)\ntarget_compile_definitions(tvm4j_runtime_packed PUBLIC ${MLC_LLM_COMPILE_DEFS})\ntarget_compile_definitions(\n  tvm4j_runtime_packed\n  PUBLIC TVM_VM_ENABLE_PROFILER=0\n  PUBLIC TVM_FFI_USE_LIBBACKTRACE=0\n  PUBLIC TVM_FFI_BACKTRACE_ON_SEGFAULT=0)\n\nset(MLC_ENABLE_SENTENCEPIECE_TOKENIZER OFF)\ntarget_link_libraries(\n  tvm4j_runtime_packed\n  tokenizers_c\n  tokenizers_cpp\n  log\n  -Wl,--whole-archive\n  mlc_llm_static\n  model_android\n  -Wl,--no-whole-archive)\n\ntarget_compile_definitions(tvm4j_runtime_packed PUBLIC TVM4J_ANDROID)\nadd_dependencies(tvm4j_runtime_packed tvm4j_core)\n\ntarget_compile_definitions(mlc_llm_objs PUBLIC MLC_SINGLE_GPU_ONLY)\n\ninstall_jar(tvm4j_core output)\ninstall(TARGETS tvm4j_runtime_packed LIBRARY DESTINATION output/${ANDROID_ABI})\n"
  },
  {
    "path": "android/mlc4j/build.gradle",
    "content": "plugins {\n    id 'com.android.library'\n    id 'org.jetbrains.kotlin.android'\n    id 'org.jetbrains.kotlin.plugin.serialization' version '1.8.0'\n}\n\nandroid {\n    namespace 'ai.mlc.mlcllm'\n    compileSdk 34\n\n    defaultConfig {\n        minSdk 22\n    }\n    compileOptions {\n        sourceCompatibility JavaVersion.VERSION_1_8\n        targetCompatibility JavaVersion.VERSION_1_8\n    }\n    kotlinOptions {\n        jvmTarget = '1.8'\n    }\n    sourceSets {\n        main {\n            jniLibs.srcDirs = ['output']\n        }\n    }\n}\n\ndependencies {\n    implementation fileTree(dir: 'output', include: ['*.jar'])\n    implementation 'androidx.core:core-ktx:1.9.0'\n    implementation 'androidx.appcompat:appcompat:1.6.1'\n    implementation 'com.google.android.material:material:1.10.0'\n    implementation 'org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3'\n}\n"
  },
  {
    "path": "android/mlc4j/prepare_libs.py",
    "content": "\"\"\"The build script for mlc4j (MLC LLM and tvm4j)\"\"\"\n\nimport argparse\nimport json\nimport os\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nfrom mlc_llm.support import logging\n\nlogging.enable_logging()\nlogger = logging.getLogger(__name__)\n\n\ndef run_cmake(mlc4j_path: Path):\n    if \"ANDROID_NDK\" not in os.environ:\n        raise ValueError(\n            'Environment variable \"ANDROID_NDK\" is required but not found. '\n            \"Please follow https://llm.mlc.ai/docs/deploy/android.html to properly \"\n            'specify \"ANDROID_NDK\".'\n        )\n    logger.info(\"Running cmake\")\n    # use pathlib so it is cross platform\n    android_ndk_path = (\n        Path(os.environ[\"ANDROID_NDK\"]) / \"build\" / \"cmake\" / \"android.toolchain.cmake\"\n    )\n    cmd = [\n        \"cmake\",\n        str(mlc4j_path),\n        \"-DCMAKE_BUILD_TYPE=Release\",\n        f\"-DCMAKE_TOOLCHAIN_FILE={str(android_ndk_path)}\",\n        \"-DCMAKE_INSTALL_PREFIX=.\",\n        '-DCMAKE_CXX_FLAGS=\"-O3\"',\n        \"-DANDROID_ABI=arm64-v8a\",\n        \"-DANDROID_NATIVE_API_LEVEL=android-24\",\n        \"-DANDROID_PLATFORM=android-24\",\n        \"-DCMAKE_FIND_ROOT_PATH_MODE_PACKAGE=ON\",\n        \"-DANDROID_STL=c++_static\",\n        \"-DUSE_HEXAGON_SDK=OFF\",\n        \"-DMLC_LLM_INSTALL_STATIC_LIB=ON\",\n        \"-DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON\",\n        \"-DUSE_OPENCL=ON\",\n        \"-DUSE_OPENCL_ENABLE_HOST_PTR=ON\",\n        \"-DUSE_CUSTOM_LOGGING=ON\",\n        \"-DTVM_FFI_USE_LIBBACKTRACE=OFF\",\n        \"-DTVM_FFI_BACKTRACE_ON_SEGFAULT=OFF\",\n    ]\n\n    if sys.platform == \"win32\":\n        logger.info(\"Using ninja in windows, make sure you installed ninja in conda\")\n        cmd += [\"-G\", \"Ninja\"]\n    subprocess.run(cmd, check=True, env=os.environ)\n\n\ndef run_cmake_build():\n    logger.info(\"Running cmake build\")\n    cmd = [\n        \"cmake\",\n        \"--build\",\n        \".\",\n        
\"--target\",\n        \"tvm4j_runtime_packed\",\n        \"--config\",\n        \"release\",\n        f\"-j{os.cpu_count()}\",\n    ]\n    subprocess.run(cmd, check=True, env=os.environ)\n\n\ndef run_cmake_install():\n    logger.info(\"Running cmake install\")\n    cmd = [\n        \"cmake\",\n        \"--build\",\n        \".\",\n        \"--target\",\n        \"install\",\n        \"--config\",\n        \"release\",\n        f\"-j{os.cpu_count()}\",\n    ]\n    subprocess.run(cmd, check=True, env=os.environ)\n\n\ndef main(mlc_llm_source_dir: Path):\n    # - Setup rust.\n    subprocess.run([\"rustup\", \"target\", \"add\", \"aarch64-linux-android\"], check=True, env=os.environ)\n\n    # - Build MLC LLM and tvm4j.\n    build_path = Path(\"build\")\n    os.makedirs(build_path / \"lib\", exist_ok=True)\n    logger.info('Entering \"%s\" for MLC LLM and tvm4j build.', os.path.abspath(build_path))\n    os.chdir(build_path)\n    # Generate config.cmake if TVM Home is set.\n    if \"TVM_SOURCE_DIR\" in os.environ:\n        logger.info('Set TVM_SOURCE_DIR to \"%s\"', os.environ[\"TVM_SOURCE_DIR\"])\n        with open(\"config.cmake\", \"w\", encoding=\"utf-8\") as file:\n            # We use \"json.dumps\" to escape backslashes and quotation marks\n            tvm_source_dir_str_with_escape = json.dumps(os.environ[\"TVM_SOURCE_DIR\"])\n            print(\"set(TVM_SOURCE_DIR %s)\" % tvm_source_dir_str_with_escape, file=file)\n\n    # - Run cmake, build and install\n    run_cmake(mlc_llm_source_dir / \"android\" / \"mlc4j\")\n    run_cmake_build()\n    run_cmake_install()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"MLC LLM Android Lib Preparation\")\n\n    parser.add_argument(\n        \"--mlc-llm-source-dir\",\n        type=Path,\n        default=os.environ.get(\"MLC_LLM_SOURCE_DIR\", None),\n        help=\"The path to MLC LLM source\",\n    )\n    parsed = parser.parse_args()\n    if parsed.mlc_llm_source_dir is None:\n        
parsed.mlc_llm_source_dir = Path(os.path.abspath(os.path.curdir)).parent.parent\n    os.environ[\"MLC_LLM_SOURCE_DIR\"] = str(parsed.mlc_llm_source_dir)\n    main(parsed.mlc_llm_source_dir)\n"
  },
  {
    "path": "android/mlc4j/src/cpp/tvm_runtime.h",
    "content": "#define TVM_USE_LIBBACKTRACE 0\n\n#include <android/log.h>\n#include <dlfcn.h>\n#include <tvm/runtime/logging.h>\n\n#include <ffi/backtrace.cc>\n#include <ffi/container.cc>\n#include <ffi/dtype.cc>\n#include <ffi/error.cc>\n#include <ffi/extra/env_c_api.cc>\n#include <ffi/extra/env_context.cc>\n#include <ffi/extra/json_parser.cc>\n#include <ffi/extra/json_writer.cc>\n#include <ffi/extra/library_module.cc>\n#include <ffi/extra/library_module_dynamic_lib.cc>\n#include <ffi/extra/library_module_system_lib.cc>\n#include <ffi/extra/module.cc>\n#include <ffi/function.cc>\n#include <ffi/object.cc>\n#include <runtime/cpu_device_api.cc>\n#include <runtime/device_api.cc>\n#include <runtime/file_utils.cc>\n#include <runtime/logging.cc>\n#include <runtime/memory/memory_manager.cc>\n#include <runtime/module.cc>\n#include <runtime/nvtx.cc>\n#include <runtime/opencl/opencl_device_api.cc>\n#include <runtime/opencl/opencl_module.cc>\n#include <runtime/opencl/opencl_wrapper/opencl_wrapper.cc>\n#include <runtime/profiling.cc>\n#include <runtime/source_utils.cc>\n#include <runtime/tensor.cc>\n#include <runtime/thread_pool.cc>\n#include <runtime/threading_backend.cc>\n#include <runtime/vm/attn_backend.cc>\n#include <runtime/vm/builtin.cc>\n#include <runtime/vm/bytecode.cc>\n#include <runtime/vm/executable.cc>\n#include <runtime/vm/kv_state.cc>\n#include <runtime/vm/paged_kv_cache.cc>\n#include <runtime/vm/rnn_state.cc>\n#include <runtime/vm/tensor_cache_support.cc>\n#include <runtime/vm/vm.cc>\n#include <runtime/workspace_pool.cc>\n\nstatic_assert(TVM_LOG_CUSTOMIZE == 1, \"TVM_LOG_CUSTOMIZE must be 1\");\n\nnamespace tvm {\nnamespace runtime {\nnamespace detail {\n// Override logging mechanism\n[[noreturn]] void LogFatalImpl(const std::string& file, int lineno, const std::string& message) {\n  std::string m = file + \":\" + std::to_string(lineno) + \": \" + message;\n  __android_log_write(ANDROID_LOG_FATAL, \"TVM_RUNTIME\", m.c_str());\n  throw InternalError(file, 
lineno, message);\n}\nvoid LogMessageImpl(const std::string& file, int lineno, int level, const std::string& message) {\n  std::string m = file + \":\" + std::to_string(lineno) + \": \" + message;\n  __android_log_write(ANDROID_LOG_DEBUG + level, \"TVM_RUNTIME\", m.c_str());\n}\n\n}  // namespace detail\n}  // namespace runtime\n}  // namespace tvm\n"
  },
  {
    "path": "android/mlc4j/src/main/AndroidManifest.xml",
    "content": "<?xml version=\"1.0\" encoding=\"utf-8\"?>\n<manifest xmlns:android=\"http://schemas.android.com/apk/res/android\">\n\n</manifest>\n"
  },
  {
    "path": "android/mlc4j/src/main/java/ai/mlc/mlcllm/JSONFFIEngine.java",
    "content": "package ai.mlc.mlcllm;\n\nimport org.apache.tvm.Device;\nimport org.apache.tvm.Function;\nimport org.apache.tvm.Module;\nimport org.apache.tvm.TVMValue;\nimport android.util.Log;\n\npublic class JSONFFIEngine {\n    private Module jsonFFIEngine;\n    private Function initBackgroundEngineFunc;\n    private Function reloadFunc;\n    private Function unloadFunc;\n    private Function resetFunc;\n    private Function chatCompletionFunc;\n    private Function abortFunc;\n    private Function getLastErrorFunc;\n    private Function runBackgroundLoopFunc;\n    private Function runBackgroundStreamBackLoopFunc;\n    private Function exitBackgroundLoopFunc;\n    private Function requestStreamCallback;\n\n    public JSONFFIEngine() {\n        Function createFunc = Function.getFunction(\"mlc.json_ffi.CreateJSONFFIEngine\");\n        assert createFunc != null;\n        jsonFFIEngine = createFunc.invoke().asModule();\n        initBackgroundEngineFunc = jsonFFIEngine.getFunction(\"init_background_engine\");\n        reloadFunc = jsonFFIEngine.getFunction(\"reload\");\n        unloadFunc = jsonFFIEngine.getFunction(\"unload\");\n        resetFunc = jsonFFIEngine.getFunction(\"reset\");\n        chatCompletionFunc = jsonFFIEngine.getFunction(\"chat_completion\");\n        abortFunc = jsonFFIEngine.getFunction(\"abort\");\n        getLastErrorFunc = jsonFFIEngine.getFunction(\"get_last_error\");\n        runBackgroundLoopFunc = jsonFFIEngine.getFunction(\"run_background_loop\");\n        runBackgroundStreamBackLoopFunc = jsonFFIEngine.getFunction(\"run_background_stream_back_loop\");\n        exitBackgroundLoopFunc = jsonFFIEngine.getFunction(\"exit_background_loop\");\n    }\n\n    public void initBackgroundEngine(KotlinFunction callback) {\n        Device device = Device.opencl();\n\n        requestStreamCallback = Function.convertFunc(new Function.Callback() {\n            @Override\n            public Object invoke(TVMValue... 
args) {\n                final String chatCompletionStreamResponsesJSONStr = args[0].asString();\n                callback.invoke(chatCompletionStreamResponsesJSONStr);\n                return 1;\n            }\n        });\n\n        initBackgroundEngineFunc.pushArg(device.deviceType).pushArg(device.deviceId).pushArg(requestStreamCallback)\n                .invoke();\n    }\n\n    public void reload(String engineConfigJSONStr) {\n        reloadFunc.pushArg(engineConfigJSONStr).invoke();\n    }\n\n    public void chatCompletion(String requestJSONStr, String requestId) {\n        chatCompletionFunc.pushArg(requestJSONStr).pushArg(requestId).invoke();\n    }\n\n    public void runBackgroundLoop() {\n        runBackgroundLoopFunc.invoke();\n    }\n\n    public void runBackgroundStreamBackLoop() {\n        runBackgroundStreamBackLoopFunc.invoke();\n    }\n\n    public void exitBackgroundLoop() {\n        exitBackgroundLoopFunc.invoke();\n    }\n\n    public void unload() {\n        unloadFunc.invoke();\n    }\n\n    public interface KotlinFunction {\n        void invoke(String arg);\n    }\n\n    public void reset() {\n        resetFunc.invoke();\n    }\n\n}\n"
  },
  {
    "path": "android/mlc4j/src/main/java/ai/mlc/mlcllm/MLCEngine.kt",
    "content": "package ai.mlc.mlcllm\n\nimport ai.mlc.mlcllm.OpenAIProtocol.*\nimport kotlinx.coroutines.GlobalScope\nimport kotlinx.coroutines.channels.Channel\nimport kotlinx.coroutines.channels.ReceiveChannel\nimport kotlinx.coroutines.launch\nimport kotlinx.serialization.json.Json\nimport kotlinx.serialization.encodeToString\nimport kotlinx.serialization.decodeFromString\nimport kotlin.concurrent.thread\nimport java.util.UUID\nimport java.util.logging.Logger\n\nclass BackgroundWorker(private val task: () -> Unit) {\n\n    fun start() {\n        thread(start = true) {\n            task()\n        }\n    }\n}\n\nclass MLCEngine {\n\n    private val state: EngineState\n    private val jsonFFIEngine: JSONFFIEngine\n    val chat: Chat\n    private val threads = mutableListOf<BackgroundWorker>()\n\n    init {\n        state = EngineState()\n        jsonFFIEngine = JSONFFIEngine()\n        chat = Chat(jsonFFIEngine, state)\n\n        jsonFFIEngine.initBackgroundEngine { result ->\n            state.streamCallback(result)\n        }\n\n        val backgroundWorker = BackgroundWorker {\n            Thread.currentThread().priority = Thread.MAX_PRIORITY\n            jsonFFIEngine.runBackgroundLoop()\n        }\n\n        val backgroundStreamBackWorker = BackgroundWorker {\n            jsonFFIEngine.runBackgroundStreamBackLoop()\n        }\n\n        threads.add(backgroundWorker)\n        threads.add(backgroundStreamBackWorker)\n\n        backgroundWorker.start()\n        backgroundStreamBackWorker.start()\n    }\n\n    fun reload(modelPath: String, modelLib: String) {\n        val engineConfig = \"\"\"\n            {\n                \"model\": \"$modelPath\",\n                \"model_lib\": \"system://$modelLib\",\n                \"mode\": \"interactive\"\n            }\n        \"\"\"\n        jsonFFIEngine.reload(engineConfig)\n    }\n\n    fun reset() {\n        jsonFFIEngine.reset()\n    }\n\n    fun unload() {\n        jsonFFIEngine.unload()\n    }\n}\n\ndata 
class RequestState(\n    val request: ChatCompletionRequest,\n    val continuation: Channel<ChatCompletionStreamResponse>\n)\n\nclass EngineState {\n\n    private val logger = Logger.getLogger(EngineState::class.java.name)\n    private val requestStateMap = mutableMapOf<String, RequestState>()\n\n    suspend fun chatCompletion(\n        jsonFFIEngine: JSONFFIEngine,\n        request: ChatCompletionRequest\n    ): ReceiveChannel<ChatCompletionStreamResponse> {\n        val json = Json { encodeDefaults = true }\n        val jsonRequest = json.encodeToString(request)\n        val requestID = UUID.randomUUID().toString()\n        val channel = Channel<ChatCompletionStreamResponse>(Channel.UNLIMITED)\n\n        requestStateMap[requestID] = RequestState(request, channel)\n\n        jsonFFIEngine.chatCompletion(jsonRequest, requestID)\n\n        return channel\n    }\n\n    fun streamCallback(result: String?) {\n        val json = Json { ignoreUnknownKeys = true }\n        try {\n            val responses: List<ChatCompletionStreamResponse> = json.decodeFromString(result ?: return)\n\n            responses.forEach { res ->\n                val requestState = requestStateMap[res.id] ?: return@forEach\n                GlobalScope.launch {\n\n                    res.usage?.let { finalUsage ->\n                        requestState.request.stream_options?.include_usage?.let { includeUsage ->\n                            if (includeUsage) {\n                                requestState.continuation.send(res)\n                            }\n                        }\n                        requestState.continuation.close()\n                        requestStateMap.remove(res.id)\n                    } ?: run {\n                        val sendResult = requestState.continuation.trySend(res)\n                        if (sendResult.isFailure) {\n                            // Handle the failure case if needed\n                            logger.severe(\"Failed to send the response: 
${sendResult.exceptionOrNull()}\")\n                        }\n                    }\n                }\n            }\n        } catch (e: Exception) {\n            logger.severe(\"Kotlin JSON parsing error: $e, jsonsrc=$result\")\n        }\n    }\n}\n\nclass Chat(\n    private val jsonFFIEngine: JSONFFIEngine,\n    private val state: EngineState\n) {\n    val completions = Completions(jsonFFIEngine, state)\n}\n\nclass Completions(\n    private val jsonFFIEngine: JSONFFIEngine,\n    private val state: EngineState\n) {\n\n    suspend fun create(request: ChatCompletionRequest): ReceiveChannel<ChatCompletionStreamResponse> {\n        return state.chatCompletion(jsonFFIEngine, request)\n    }\n\n    suspend fun create(\n        messages: List<ChatCompletionMessage>,\n        model: String? = null,\n        frequency_penalty: Float? = null,\n        presence_penalty: Float? = null,\n        logprobs: Boolean = false,\n        top_logprobs: Int = 0,\n        logit_bias: Map<Int, Float>? = null,\n        max_tokens: Int? = null,\n        n: Int = 1,\n        seed: Int? = null,\n        stop: List<String>? = null,\n        stream: Boolean = true,\n        stream_options: StreamOptions? = null,\n        temperature: Float? = null,\n        top_p: Float? = null,\n        tools: List<ChatTool>? = null,\n        user: String? = null,\n        response_format: ResponseFormat? 
= null\n    ): ReceiveChannel<ChatCompletionStreamResponse> {\n        if (!stream) {\n            throw IllegalArgumentException(\"Only stream=true is supported in MLCKotlin\")\n        }\n\n        val request = ChatCompletionRequest(\n            messages = messages,\n            model = model,\n            frequency_penalty = frequency_penalty,\n            presence_penalty = presence_penalty,\n            logprobs = logprobs,\n            top_logprobs = top_logprobs,\n            logit_bias = logit_bias,\n            max_tokens = max_tokens,\n            n = n,\n            seed = seed,\n            stop = stop,\n            stream = stream,\n            stream_options = stream_options,\n            temperature = temperature,\n            top_p = top_p,\n            tools = tools,\n            user = user,\n            response_format = response_format\n        )\n        return create(request)\n    }\n}\n"
  },
  {
    "path": "android/mlc4j/src/main/java/ai/mlc/mlcllm/OpenAIProtocol.kt",
    "content": "package ai.mlc.mlcllm\n\nimport kotlinx.serialization.KSerializer\nimport kotlinx.serialization.Serializable\nimport kotlinx.serialization.builtins.ListSerializer\nimport kotlinx.serialization.builtins.MapSerializer\nimport kotlinx.serialization.builtins.serializer\nimport kotlinx.serialization.descriptors.SerialDescriptor\nimport kotlinx.serialization.descriptors.buildClassSerialDescriptor\nimport kotlinx.serialization.encoding.Decoder\nimport kotlinx.serialization.encoding.Encoder\nimport kotlinx.serialization.json.JsonArray\nimport kotlinx.serialization.json.JsonElement\nimport kotlinx.serialization.json.JsonObject\nimport kotlinx.serialization.json.JsonPrimitive\nimport kotlinx.serialization.json.jsonPrimitive\nimport java.util.*\n\n// Data classes for v1/chat/completions\n// API reference: https://platform.openai.com/docs/api-reference/chat/create\n\nclass OpenAIProtocol {\n    @Serializable\n    data class TopLogProbs(\n        val token: String,\n        val logprob: Float,\n        val bytes: List<Int>? = null\n    )\n\n    @Serializable\n    data class LogProbsContent(\n        val token: String,\n        val logprob: Float,\n        var bytes: List<Int>? = null,\n        var top_logprobs: List<TopLogProbs> = listOf()\n    )\n\n    @Serializable\n    data class LogProbs(\n        var content: List<LogProbsContent> = listOf()\n    )\n\n    @Serializable\n    data class ChatFunction(\n        val name: String,\n        var description: String? = null,\n        val parameters: Map<String, String>\n    )\n\n    @Serializable\n    data class ChatTool(\n        val type: String = \"function\",\n        val function: ChatFunction\n    )\n\n    @Serializable\n    data class ChatFunctionCall(\n        val name: String,\n        // NOTE: arguments should be dict str to any codable\n        // for now only allow string output due to typing issues\n        var arguments: Map<String, String>? 
= null\n    )\n\n    @Serializable\n    data class ChatToolCall(\n        val id: String = UUID.randomUUID().toString(),\n        val type: String = \"function\",\n        val function: ChatFunctionCall\n    )\n\n    @Serializable\n    enum class ChatCompletionRole {\n        system,\n        user,\n        assistant,\n        tool\n    }\n\n    @Serializable(with = ChatCompletionMessageContentSerializer::class)\n    data class ChatCompletionMessageContent(\n        val text: String? = null,\n        val parts: List<Map<String, String>>? = null\n    ) {\n        constructor(text: String) : this(text, null)\n        constructor(parts: List<Map<String, String>>) : this(null, parts)\n\n        fun isText(): Boolean {\n            return text != null\n        }\n\n        fun isParts(): Boolean {\n            return parts != null\n        }\n\n        fun asText(): String {\n            return text ?: (parts?.filter { it[\"type\"] == \"text\" }?.joinToString(\"\") { it[\"text\"] ?: \"\" } ?: \"\")\n        }\n    }\n\n    object ChatCompletionMessageContentSerializer : KSerializer<ChatCompletionMessageContent> {\n        override val descriptor: SerialDescriptor = buildClassSerialDescriptor(\"ChatCompletionMessageContent\") {\n            element(\"text\", String.serializer().descriptor)\n            element(\"parts\", ListSerializer(MapSerializer(String.serializer(), String.serializer())).descriptor)\n        }\n\n        override fun serialize(encoder: Encoder, value: ChatCompletionMessageContent) {\n            if (value.isText()) {\n                encoder.encodeString(value.text!!)\n            } else {\n                encoder.encodeSerializableValue(ListSerializer(MapSerializer(String.serializer(), String.serializer())), value.parts ?: listOf())\n            }\n        }\n\n        override fun deserialize(decoder: Decoder): ChatCompletionMessageContent {\n            return when (val element = decoder.decodeSerializableValue(JsonElement.serializer())) {\n       
         is JsonArray -> {\n                    val parts = element.map { (it as JsonObject).map { entry -> entry.key to entry.value.jsonPrimitive.content }.toMap() }\n                    ChatCompletionMessageContent(parts)\n                }\n                is JsonPrimitive -> {\n                    ChatCompletionMessageContent(element.content)\n                }\n                else -> throw IllegalStateException(\"Unexpected JsonElement type\")\n            }\n        }\n    }\n\n    @Serializable\n    data class ChatCompletionMessage(\n        val role: ChatCompletionRole,\n        var content: ChatCompletionMessageContent? = null,\n        var name: String? = null,\n        var tool_calls: List<ChatToolCall>? = null,\n        var tool_call_id: String? = null\n    ) {\n        constructor(\n            role: ChatCompletionRole,\n            content: String,\n            name: String? = null,\n            tool_calls: List<ChatToolCall>? = null,\n            tool_call_id: String? = null\n        ) : this(role, ChatCompletionMessageContent(content), name, tool_calls, tool_call_id)\n    }\n\n    @Serializable\n    data class CompletionUsageExtra(\n        val prefill_tokens_per_s: Float? = null,\n        val decode_tokens_per_s: Float? = null,\n        val num_prefill_tokens: Int? 
= null\n    ) {\n        fun asTextLabel(): String {\n            var outputText = \"\"\n            if (prefill_tokens_per_s != null) {\n                outputText += \"prefill: ${String.format(\"%.1f\", prefill_tokens_per_s)} tok/s\"\n            }\n            if (decode_tokens_per_s != null) {\n                if (outputText.isNotEmpty()) {\n                    outputText += \", \"\n                }\n                outputText += \"decode: ${String.format(\"%.1f\", decode_tokens_per_s)} tok/s\"\n            }\n            return outputText\n        }\n    }\n\n    @Serializable\n    data class CompletionUsage(\n        val prompt_tokens: Int,\n        val completion_tokens: Int,\n        val total_tokens: Int,\n        val extra: CompletionUsageExtra? = null\n    )\n\n    @Serializable\n    data class StreamOptions(\n        val include_usage: Boolean = false\n    )\n\n    @Serializable\n    data class ChatCompletionStreamResponseChoice(\n        var finish_reason: String? = null,\n        val index: Int,\n        val delta: ChatCompletionMessage,\n        var logprobs: LogProbs? = null\n    )\n\n    @Serializable\n    data class ChatCompletionStreamResponse(\n        val id: String,\n        var choices: List<ChatCompletionStreamResponseChoice> = listOf(),\n        var created: Int? = null,\n        var model: String? = null,\n        val system_fingerprint: String,\n        var `object`: String? = null,\n        val usage: CompletionUsage? = null\n    )\n\n    @Serializable\n    data class ChatCompletionRequest(\n        val messages: List<ChatCompletionMessage>,\n        val model: String? = null,\n        val frequency_penalty: Float? = null,\n        val presence_penalty: Float? = null,\n        val logprobs: Boolean = false,\n        val top_logprobs: Int = 0,\n        val logit_bias: Map<Int, Float>? = null,\n        val max_tokens: Int? = null,\n        val n: Int = 1,\n        val seed: Int? = null,\n        val stop: List<String>? 
= null,\n        val stream: Boolean = true,\n        val stream_options: StreamOptions? = null,\n        val temperature: Float? = null,\n        val top_p: Float? = null,\n        val tools: List<ChatTool>? = null,\n        val user: String? = null,\n        val response_format: ResponseFormat? = null\n    )\n\n    @Serializable\n    data class ResponseFormat(\n        val type: String,\n        val schema: String? = null\n    )\n}\n"
  },
  {
    "path": "ci/bash.sh",
    "content": "#!/usr/bin/env bash\n\nif [ \"$#\" -lt 1 ]; then\n    echo \"Usage: ci/bash.sh <CONTAINER_NAME> -e key value -v key value [COMMAND]\"\n    exit -1\nfi\n\nDOCKER_IMAGE_NAME=(\"$1\")\nSCRIPT_DIR=\"$(cd \"$(dirname \"${BASH_SOURCE[0]}\")\" && pwd)\"\nWORKSPACE=\"$(pwd)\"\nDOCKER_BINARY=\"docker\"\nDOCKER_ENV=\"-e ENV_USER_ID=$(id -u) -e ENV_GROUP_ID=$(id -g)\"\nDOCKER_VOLUMNS=\"-v ${WORKSPACE}:/workspace -v ${SCRIPT_DIR}:/docker\"\n\nshift 1\nwhile [[ $# -gt 0 ]]; do\n    cmd=\"$1\"\n    if [[ $cmd == \"-e\" ]]; then\n        env_key=$2\n        env_value=$3\n        shift 3\n        DOCKER_ENV=\"${DOCKER_ENV} -e ${env_key}=${env_value}\"\n    elif [[ $cmd == \"-v\" ]]; then\n        volumn_key=$2\n        volumn_value=$3\n        shift 3\n        DOCKER_VOLUMNS=\"${DOCKER_VOLUMNS} -v ${volumn_key}:${volumn_value}\"\n    elif [[ $cmd == \"-j\" ]]; then\n        num_threads=$2\n        shift 2\n        DOCKER_ENV=\"${DOCKER_ENV} -e NUM_THREADS=${num_threads} --cpus ${num_threads}\"\n    else\n        break\n    fi\ndone\n\nif [ \"$#\" -eq 0 ]; then\n    COMMAND=\"bash\"\n    if [[ $(uname) == \"Darwin\" ]]; then\n        # Docker's host networking driver isn't supported on macOS.\n        # Use default bridge network and expose port for jupyter notebook.\n        DOCKER_EXTRA_PARAMS=(\"-it -p 8888:8888\")\n    else\n        DOCKER_EXTRA_PARAMS=(\"-it --net=host\")\n    fi\nelse\n    COMMAND=(\"$@\")\nfi\n\nif [[ -n ${MLC_CI_SETUP_DEPS:-} ]]; then\n    DOCKER_ENV=\"${DOCKER_ENV} -e MLC_CI_SETUP_DEPS=${MLC_CI_SETUP_DEPS}\"\nfi\n\n# Use nvidia-docker if the container is GPU.\nif [[ -n ${CUDA_VISIBLE_DEVICES:-} ]]; then\n    DOCKER_ENV=\"${DOCKER_ENV} -e CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES}\"\n    if type nvidia-docker 1> /dev/null 2> /dev/null; then\n        DOCKER_BINARY=nvidia-docker\n    else\n        DOCKER_BINARY=docker\n        DOCKER_ENV=\"${DOCKER_ENV} --gpus all\"\n    fi\n\n    # nvidia-docker treats Vulkan as a graphics API, so we need 
to\n    # request passthrough of graphics APIs.  This could also be set in\n    # the Dockerfile.\n    DOCKER_ENV=\"${DOCKER_ENV} -e NVIDIA_DRIVER_CAPABILITIES=compute,graphics,utility\"\n\n    # Vulkan compatibility\n    ICD_SEARCH_LOCATIONS=(\n        # https://github.com/KhronosGroup/Vulkan-Loader/blob/master/loader/LoaderAndLayerInterface.md#icd-discovery-on-linux\n        /usr/local/etc/vulkan/icd.d\n        /usr/local/share/vulkan/icd.d\n        /etc/vulkan/icd.d\n        /usr/share/vulkan/icd.d\n        /etc/glvnd/egl_vendor.d\n        /usr/share/glvnd/egl_vendor.d\n    )\n    for filename in $(find \"${ICD_SEARCH_LOCATIONS[@]}\" -name \"*nvidia*.json\" 2> /dev/null); do\n        DOCKER_VOLUMNS=\"${DOCKER_VOLUMNS} -v ${filename}:${filename}:ro\"\n    done\nfi\n\n# Print arguments.\necho \"DOCKER_BINARY ${DOCKER_BINARY}\"\necho \"WORKSPACE: ${WORKSPACE}\"\necho \"IMAGE NAME: ${DOCKER_IMAGE_NAME}\"\necho \"ENV VARIABLES: ${DOCKER_ENV}\"\necho \"VOLUMES: ${DOCKER_VOLUMNS}\"\necho \"COMMANDS: '${COMMAND[@]}'\"\n\n# By default we clean up: remove the container once it finishes running (--rm)\n# and share the PID namespace (--pid=host) so the process inside does not have\n# pid 1 and SIGKILL is propagated to the process inside (Jenkins can kill it).\n\n${DOCKER_BINARY} run --rm --pid=host \\\n    -w /workspace \\\n    ${DOCKER_VOLUMNS} \\\n    ${DOCKER_ENV} \\\n    ${DOCKER_EXTRA_PARAMS[@]} \\\n    ${DOCKER_IMAGE_NAME} \\\n    ${COMMAND[@]}\n"
  },
  {
    "path": "ci/build-environment.yaml",
    "content": "name: mlc-llm-build\n\nchannels:\n  - conda-forge\n\ndependencies:\n  - conda-build\n  - anaconda-client\n  - libvulkan-headers\n  - libvulkan-loader\n  - spirv-tools\n  - spirv-headers\n  - git\n  - cmake<4.0\n  - bzip2\n"
  },
  {
    "path": "ci/jenkinsfile.groovy",
    "content": "// Licensed to the Apache Software Foundation (ASF) under one\n// or more contributor license agreements.  See the NOTICE file\n// distributed with this work for additional information\n// regarding copyright ownership.  The ASF licenses this file\n// to you under the Apache License, Version 2.0 (the\n// \"License\"); you may not use this file except in compliance\n// with the License.  You may obtain a copy of the License at\n//\n//   http://www.apache.org/licenses/LICENSE-2.0\n//\n// Unless required by applicable law or agreed to in writing,\n// software distributed under the License is distributed on an\n// \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n// KIND, either express or implied.  See the License for the\n// specific language governing permissions and limitations\n// under the License.\n\nimport org.jenkinsci.plugins.pipeline.modeldefinition.Utils\n\nrun_cpu = \"bash ci/bash.sh mlcaidev/ci-cpu:26d65cc -e GPU cpu -e MLC_CI_SETUP_DEPS 1\"\nrun_cuda = \"bash ci/bash.sh mlcaidev/ci-cu128:26d65cc -e GPU cuda-12.8 -e MLC_CI_SETUP_DEPS 1\"\n// run_rocm = \"bash ci/bash.sh mlcaidev/ci-rocm57:26d65cc -e GPU rocm-5.7 -e MLC_CI_SETUP_DEPS 1\"\n\npkg_cpu = \"bash ci/bash.sh mlcaidev/package-rocm61:519d0b3 -e GPU cpu -e MLC_CI_SETUP_DEPS 1\"\npkg_cuda = \"bash ci/bash.sh mlcaidev/package-cu128:519d0b3 -e GPU cuda-12.8 -e MLC_CI_SETUP_DEPS 1\"\npkg_rocm = \"bash ci/bash.sh mlcaidev/package-rocm61:519d0b3 -e GPU rocm-6.1 -e MLC_CI_SETUP_DEPS 1\"\n\n\ndef per_exec_ws(folder) {\n  return \"workspace/exec_${env.EXECUTOR_NUMBER}/\" + folder\n}\n\ndef pack_lib(name, libs) {\n  sh \"\"\"\n     echo \"Packing ${libs} into ${name}\"\n     echo ${libs} | sed -e 's/,/ /g' | xargs md5sum\n     \"\"\"\n  stash includes: libs, name: name\n}\n\ndef unpack_lib(name, libs) {\n  unstash name\n  sh \"\"\"\n     echo \"Unpacked ${libs} from ${name}\"\n     echo ${libs} | sed -e 's/,/ /g' | xargs md5sum\n     \"\"\"\n}\n\ndef init_git(submodule = false) {\n  
cleanWs()\n  // add retry in case checkout timeouts\n  retry(5) {\n    checkout scm\n  }\n  if (submodule) {\n    retry(5) {\n      timeout(time: 10, unit: 'MINUTES') {\n        sh(script: 'git submodule update --init --recursive -f', label: 'Update git submodules')\n      }\n    }\n  }\n}\n\nstage('Lint') {\n  parallel(\n    'isort': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-lint-isort')) {\n          init_git()\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${run_cpu} conda env export --name ci-lint\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 1 conda run -n ci-lint ci/task/isort.sh\", label: 'Lint')\n        }\n      }\n    },\n    'black': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-lint-black')) {\n          init_git()\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${run_cpu} conda env export --name ci-lint\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 1 conda run -n ci-lint ci/task/black.sh\", label: 'Lint')\n        }\n      }\n    },\n    'mypy': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-lint-mypy')) {\n          init_git()\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${run_cpu} conda env export --name ci-lint\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 1 conda run -n ci-lint ci/task/mypy.sh\", label: 'Lint')\n        }\n      }\n    },\n    'pylint': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-lint-pylint')) {\n          init_git()\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${run_cpu} conda env export --name ci-lint\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 4 conda run -n ci-lint ci/task/pylint.sh\", label: 'Lint')\n        }\n      }\n    },\n    'clang-format': {\n      node('CPU-SMALL') {\n        
ws(per_exec_ws('mlc-llm-lint-clang-format')) {\n          init_git()\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${run_cpu} conda env export --name ci-lint\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 1 conda run -n ci-lint ci/task/clang-format.sh\", label: 'Lint')\n        }\n      }\n    },\n  )\n}\n\nstage('Build') {\n  parallel(\n    'CUDA': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-build-cuda')) {\n          init_git(true)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${pkg_cuda} conda env export --name py312\", label: 'Checkout version')\n          sh(script: \"${pkg_cuda} -j 8 -v \\$HOME/.ccache /ccache conda run -n py312 ./ci/task/build_lib.sh\", label: 'Build MLC LLM runtime')\n          sh(script: \"${pkg_cuda} -j 1 conda run -n py312 ./ci/task/build_clean.sh\", label: 'Clean up after build')\n          sh(script: \"ls -alh ./wheels/\", label: 'Build artifact')\n          pack_lib('mlc_wheel_cuda', 'wheels/*.whl')\n        }\n      }\n    },\n    // 'ROCm': {\n    //   node('CPU-SMALL') {\n    //     ws(per_exec_ws('mlc-llm-build-rocm')) {\n    //       init_git(true)\n    //       sh(script: \"ls -alh\", label: 'Show work directory')\n    //       sh(script: \"${pkg_rocm} conda env export --name py38\", label: 'Checkout version')\n    //       sh(script: \"${pkg_rocm} -j 8 conda run -n py38 ./ci/task/build_lib.sh\", label: 'Build MLC LLM runtime')\n    //       sh(script: \"${pkg_rocm} -j 1 conda run -n py38 ./ci/task/build_clean.sh\", label: 'Clean up after build')\n    //       sh(script: \"ls -alh ./wheels/\", label: 'Build artifact')\n    //       pack_lib('mlc_wheel_rocm', 'wheels/*.whl')\n    //     }\n    //   }\n    // },\n    'Metal': {\n      node('MAC') {\n        ws(per_exec_ws('mlc-llm-build-metal')) {\n          init_git(true)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: 
\"conda env export --name mlc-llm-ci\", label: 'Checkout version')\n          sh(script: \"NUM_THREADS=6 GPU=metal conda run -n mlc-llm-ci ./ci/task/build_lib.sh\", label: 'Build MLC LLM runtime')\n          sh(script: \"NUM_THREADS=6 GPU=metal conda run -n mlc-llm-ci ./ci/task/build_clean.sh\", label: 'Clean up after build')\n          sh(script: \"ls -alh ./wheels/\", label: 'Build artifact')\n          pack_lib('mlc_wheel_metal', 'wheels/*.whl')\n        }\n      }\n    },\n    'Vulkan': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-build-vulkan')) {\n          init_git(true)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          sh(script: \"${pkg_cpu} conda env export --name py312\", label: 'Checkout version')\n          sh(script: \"${pkg_cpu} -j 8 conda run -n py312 ./ci/task/build_lib.sh\", label: 'Build MLC LLM runtime')\n          sh(script: \"${pkg_cpu} -j 1 conda run -n py312 ./ci/task/build_clean.sh\", label: 'Clean up after build')\n          sh(script: \"ls -alh ./wheels/\", label: 'Build artifact')\n          pack_lib('mlc_wheel_vulkan', 'wheels/*.whl')\n        }\n      }\n    }\n  )\n}\n\nstage('Unittest') {\n  parallel(\n    'CUDA': {\n      node('GPU') {\n        ws(per_exec_ws('mlc-llm-unittest')) {\n          init_git(false)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_cuda', 'wheels/*.whl')\n          sh(script: \"${run_cuda} conda env export --name ci-unittest\", label: 'Checkout version')\n          sh(script: \"${run_cuda} conda run -n ci-unittest ./ci/task/test_unittest.sh\", label: 'Testing')\n        }\n      }\n    }\n  )\n}\n\nstage('Model Compilation') {\n  parallel(\n    'CUDA': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-compile-cuda')) {\n          init_git(false)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_cuda', 'wheels/*.whl')\n          sh(script: \"${run_cuda} conda 
env export --name ci-unittest\", label: 'Checkout version')\n          sh(script: \"${run_cuda} -j 4 conda run -n ci-unittest ./ci/task/test_model_compile.sh\", label: 'Testing')\n        }\n      }\n    },\n    // 'ROCm': {\n    //   node('CPU-SMALL') {\n    //     ws(per_exec_ws('mlc-llm-compile-rocm')) {\n    //       init_git(false)\n    //       sh(script: \"ls -alh\", label: 'Show work directory')\n    //       unpack_lib('mlc_wheel_rocm', 'wheels/*.whl')\n    //       sh(script: \"${run_rocm} conda env export --name ci-unittest\", label: 'Checkout version')\n    //       sh(script: \"${run_rocm} -j 4 conda run -n ci-unittest ./ci/task/test_model_compile.sh\", label: 'Testing')\n    //     }\n    //   }\n    // },\n    'Metal': {\n      node('MAC') {\n        ws(per_exec_ws('mlc-llm-compile-metal')) {\n          init_git(false)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_metal', 'wheels/*.whl')\n          sh(script: \"conda env export --name mlc-llm-ci\", label: 'Checkout version')\n          sh(script: \"NUM_THREADS=6 GPU=metal conda run -n mlc-llm-ci ./ci/task/test_model_compile.sh\", label: 'Testing')\n        }\n      }\n    },\n    'Vulkan': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-compile-vulkan')) {\n          init_git(false)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_vulkan', 'wheels/*.whl')\n          sh(script: \"${run_cpu} conda env export --name ci-unittest\", label: 'Checkout version')\n          // sh(script: \"${run_cpu} -j 4 conda run -n ci-unittest ./ci/task/test_model_compile.sh\", label: 'Testing')\n        }\n      }\n    },\n    'WASM': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-compile-wasm')) {\n          init_git(true)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_vulkan', 'wheels/*.whl')\n          sh(script: \"${run_cpu} conda env 
export --name ci-unittest\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 8 -e GPU wasm conda run -n ci-unittest ./ci/task/test_model_compile.sh\", label: 'Testing')\n        }\n      }\n    },\n    'iOS': {\n      node('MAC') {\n        ws(per_exec_ws('mlc-llm-compile-ios')) {\n          init_git(false)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_metal', 'wheels/*.whl')\n          sh(script: \"conda env export --name mlc-llm-ci\", label: 'Checkout version')\n          sh(script: \"NUM_THREADS=6 GPU=ios conda run -n mlc-llm-ci ./ci/task/test_model_compile.sh\", label: 'Testing')\n        }\n      }\n    },\n    'Android-OpenCL': {\n      node('CPU-SMALL') {\n        ws(per_exec_ws('mlc-llm-compile-android')) {\n          init_git(false)\n          sh(script: \"ls -alh\", label: 'Show work directory')\n          unpack_lib('mlc_wheel_vulkan', 'wheels/*.whl')\n          sh(script: \"${run_cpu} conda env export --name ci-unittest\", label: 'Checkout version')\n          sh(script: \"${run_cpu} -j 4 -e GPU android conda run -n ci-unittest ./ci/task/test_model_compile.sh\", label: 'Testing')\n        }\n      }\n    }\n  )\n}\n"
  },
  {
    "path": "ci/task/black.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\n\nblack --diff --check --workers $NUM_THREADS \\\n    ./python/ \\\n    ./tests/python \\\n    ./examples/python\n"
  },
  {
    "path": "ci/task/build_clean.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\n\nrm -rf ${WORKSPACE_CWD}/build/ \\\n    ${WORKSPACE_CWD}/dist/\n"
  },
  {
    "path": "ci/task/build_lib.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\nexport CCACHE_COMPILERCHECK=content\nexport CCACHE_NOHASHDIR=1\nexport CCACHE_DIR=/ccache\n\n# Temporary workaround to install ccache.\nif [[ ${GPU} != metal ]]; then\n    conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main\n    conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r\nfi\nconda install -c conda-forge ccache\n\nif [[ ${GPU} != metal ]]; then\n    source /multibuild/manylinux_utils.sh\n    source /opt/rh/gcc-toolset-11/enable # GCC-11 is the highest GCC version compatible with NVCC < 12\nfi\n\nmkdir -p $WORKSPACE_CWD/build\nif [[ ${GPU} == rocm* ]]; then\n    echo set\\(USE_VULKAN ON\\) >>config.cmake\n    echo set\\(USE_ROCM ON\\) >>config.cmake\n    echo set\\(USE_RCCL /opt/rocm/rccl/ \\) >>config.cmake\nelif [[ ${GPU} == cuda* ]]; then\n    echo set\\(USE_VULKAN ON\\) >>config.cmake\n    echo set\\(CMAKE_CUDA_COMPILER_LAUNCHER ccache\\) >>config.cmake\n    echo set\\(CMAKE_CUDA_ARCHITECTURES \"80;90;100;120\"\\) >>config.cmake\n    echo set\\(CMAKE_CUDA_FLAGS \\\"\\$\\{CMAKE_CUDA_FLAGS\\} -t $NUM_THREADS\\\"\\) >>config.cmake\n    echo set\\(USE_CUDA ON\\) >>config.cmake\n    echo set\\(USE_CUBLAS ON\\) >>config.cmake\n    echo set\\(USE_NCCL ON\\) >>config.cmake\nelif [[ ${GPU} == metal ]]; then\n    export CCACHE_DIR=$HOME/ci/ccache\n    echo set\\(USE_METAL ON\\) >>config.cmake\nelse\n    echo set\\(USE_VULKAN ON\\) >>config.cmake\nfi\n\ncat config.cmake\n\nAUDITWHEEL_OPTS=\"--plat ${AUDITWHEEL_PLAT} -w repaired_wheels/\"\nAUDITWHEEL_OPTS=\"--exclude libtvm --exclude libtvm_runtime --exclude libtvm_ffi --exclude libvulkan ${AUDITWHEEL_OPTS}\"\nif [[ ${GPU} == rocm* ]]; then\n    AUDITWHEEL_OPTS=\"--exclude libamdhip64 --exclude libhsa-runtime64 --exclude librocm_smi64 --exclude librccl ${AUDITWHEEL_OPTS}\"\nelif [[ ${GPU} == cuda* ]]; then\n    
AUDITWHEEL_OPTS=\"--exclude libcuda --exclude libcudart --exclude libnvrtc --exclude libcublas --exclude libcublasLt ${AUDITWHEEL_OPTS}\"\nfi\n\nrm -rf ${WORKSPACE_CWD}/dist\ncd ${WORKSPACE_CWD} && pip wheel --no-deps -w dist . -v\n\nrm -rf ${WORKSPACE_CWD}/wheels/\nif [[ ${GPU} != metal ]]; then\n    mkdir -p ${WORKSPACE_CWD}/repaired_wheels\n    rm -rf ${WORKSPACE_CWD}/repaired_wheels/*\n    auditwheel repair ${AUDITWHEEL_OPTS} dist/*.whl\n    mv ${WORKSPACE_CWD}/repaired_wheels/ ${WORKSPACE_CWD}/wheels/\nelse\n    mkdir ${WORKSPACE_CWD}/wheels/\n    mv dist/*.whl ${WORKSPACE_CWD}/wheels/\nfi\n\nchown -R $ENV_USER_ID:$ENV_GROUP_ID ${WORKSPACE_CWD}/wheels/\n"
  },
  {
    "path": "ci/task/build_win.bat",
    "content": "cd mlc-llm\nrd /s /q build\nmkdir build\n\necho set(USE_VULKAN ON) >> config.cmake\n\npip install . -v\n\nif %errorlevel% neq 0 exit %errorlevel%\n"
  },
  {
    "path": "ci/task/clang-format.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\n\nINPLACE_FORMAT=${INPLACE_FORMAT:=false}\nLINT_ALL_FILES=true\nREVISION=$(git rev-list --max-parents=0 HEAD)\n\nwhile (($#)); do\n    case \"$1\" in\n    -i)\n        INPLACE_FORMAT=true\n        shift 1\n        ;;\n    --rev)\n        LINT_ALL_FILES=false\n        REVISION=$2\n        shift 2\n        ;;\n    *)\n        echo \"Usage: clang-format.sh [-i] [--rev <commit>]\"\n        echo \"\"\n        echo \"Run clang-format on files that changed since <commit> or on all files in the repo\"\n        echo \"Examples:\"\n        echo \"- Compare last one commit: clang-format.sh --rev HEAD~1\"\n        echo \"- Compare against upstream/main: clang-format.sh --rev upstream/main\"\n        echo \"The -i will format files in-place instead of checking them.\"\n        exit 1\n        ;;\n    esac\ndone\n\ncleanup() {\n    if [ -f /tmp/$$.clang-format.txt ]; then\n        echo \"\"\n        echo \"---------clang-format log----------\"\n        cat /tmp/$$.clang-format.txt\n    fi\n    rm -rf /tmp/$$.clang-format.txt\n}\ntrap cleanup 0\n\nif [[ \"$INPLACE_FORMAT\" == \"true\" ]]; then\n    echo \"Running inplace git-clang-format against $REVISION\"\n    git-clang-format --extensions h,hh,hpp,c,cc,cpp,mm \"$REVISION\"\n    exit 0\nfi\n\nif [[ \"$LINT_ALL_FILES\" == \"true\" ]]; then\n    echo \"Running git-clang-format against all C++ files\"\n    git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm \"$REVISION\" 1>/tmp/$$.clang-format.txt\nelse\n    echo \"Running git-clang-format against $REVISION\"\n    git-clang-format --diff --extensions h,hh,hpp,c,cc,cpp,mm \"$REVISION\" 1>/tmp/$$.clang-format.txt\nfi\n\nif grep --quiet -E \"diff\" </tmp/$$.clang-format.txt; then\n    echo \"clang-format lint error found. Consider running clang-format on these files to fix them.\"\n    exit 1\nfi\n"
  },
  {
    "path": "ci/task/isort.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\n\nisort --check-only -j $NUM_THREADS --profile black \\\n    ./python/ \\\n    ./tests/python/ \\\n    ./examples/python\n"
  },
  {
    "path": "ci/task/mypy.sh",
    "content": "#!/bin/bash\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\n\nmypy --install-types --non-interactive ./python/ ./tests/python/ ./examples/python/\n"
  },
  {
    "path": "ci/task/pylint.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\nexport PYTHONPATH=\"./python\":${PYTHONPATH:-\"\"}\n\nif [[ -n ${MLC_CI_SETUP_DEPS:-} ]]; then\n    echo \"MLC_CI_SETUP_DEPS=1 start setup deps\"\n    # TVM Unity is a dependency of these lint checks\n    pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cpu\n    pip install apache-tvm-ffi\n    pip install requests triton\n    pip install --pre -U cuda-python\nfi\n\npylint --jobs $NUM_THREADS ./python/\npylint --jobs $NUM_THREADS --recursive=y ./tests/python/\npylint --jobs $NUM_THREADS --recursive=y ./examples/python/\n"
  },
  {
    "path": "ci/task/test_model_compile.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n: ${NUM_THREADS:=$(nproc)}\n: ${WORKSPACE_CWD:=$(pwd)}\n: ${GPU:=\"cpu\"}\n\npip install --force-reinstall wheels/*.whl\n\nif [[ ${GPU} == cuda* ]]; then\n    TARGET=cuda\n    pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cu128\n    export LD_LIBRARY_PATH=/usr/local/cuda/compat/:$LD_LIBRARY_PATH\nelif [[ ${GPU} == rocm* ]]; then\n    TARGET=rocm\n    pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-rocm57\nelif [[ ${GPU} == metal ]]; then\n    TARGET=metal\n    pip install --pre -U --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cpu\nelif [[ ${GPU} == wasm* ]]; then\n    TARGET=wasm\n    # Clone a copy of the TVM source code to build the TVM web runtime\n    git clone https://github.com/mlc-ai/relax.git /tmp/tvm --recursive\n    export TVM_SOURCE_DIR=/tmp/tvm\n    # Pip install tvm so that `import tvm` in Python works\n    pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cpu\n    export TVM_HOME=${TVM_SOURCE_DIR}\n    export MLC_LLM_SOURCE_DIR=$(pwd)\n    cd $TVM_SOURCE_DIR/web/ && make -j${NUM_THREADS} && cd -\n    cd $MLC_LLM_SOURCE_DIR/web/ && make -j${NUM_THREADS} && cd -\nelif [[ ${GPU} == ios ]]; then\n    TARGET=ios\n    pip install --pre -U --force-reinstall -f https://mlc.ai/wheels mlc-ai-nightly-cpu\nelif [[ ${GPU} == android* ]]; then\n    TARGET=android\n    pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cpu\nelse\n    TARGET=vulkan\n    pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cpu\nfi\n\npython tests/python/integration/test_model_compile.py $TARGET $NUM_THREADS\n"
  },
  {
    "path": "ci/task/test_unittest.sh",
    "content": "#!/bin/bash\nset -eo pipefail\nset -x\n\n# This block only runs in CI_ENV, where these environment variables are passed\nif [[ -n ${MLC_CI_SETUP_DEPS:-} ]]; then\n    echo \"MLC_CI_SETUP_DEPS=1 start setup deps..\"\n    # Install dependencies\n    pip install --force-reinstall wheels/*.whl\n    pip install --quiet pytest\n    pip install --pre -U --no-index -f https://mlc.ai/wheels mlc-ai-nightly-cu128\n    export LD_LIBRARY_PATH=/usr/local/cuda/compat/:$LD_LIBRARY_PATH\nfi\n\n# Run all tests that are marked as \"unittest\".\n# Add pytestmark = [pytest.mark.unittest] to a test file\n# so that its tests are run here.\npython -m pytest -v tests/python/ -m unittest\n"
  },
  {
    "path": "cmake/gen_cmake_config.py",
    "content": "from collections import namedtuple\n\nBackend = namedtuple(\"Backend\", [\"name\", \"cmake_config_name\", \"prompt_str\", \"parent\"])\n\nif __name__ == \"__main__\":\n    tvm_home = \"\"  # pylint: disable=invalid-name\n\n    tvm_home = input(\n        \"Enter TVM_SOURCE_DIR as an absolute path. If not specified, 3rdparty/tvm will be used by default: \"\n    )\n    if len(tvm_home) == 0:\n        tvm_home = \"3rdparty/tvm\"  # pylint: disable=invalid-name\n\n    cmake_config_str = f\"set(TVM_SOURCE_DIR {tvm_home})\\n\"\n    cmake_config_str += \"set(CMAKE_BUILD_TYPE RelWithDebInfo)\\n\"\n    cuda_backend = Backend(\"CUDA\", \"USE_CUDA\", \"Use CUDA? (y/n): \", None)\n    opencl_backend = Backend(\"OpenCL\", \"USE_OPENCL\", \"Use OpenCL? (y/n): \", None)\n    backends = [\n        cuda_backend,\n        Backend(\"CUTLASS\", \"USE_CUTLASS\", \"Use CUTLASS? (y/n): \", cuda_backend),\n        Backend(\"CUBLAS\", \"USE_CUBLAS\", \"Use CUBLAS? (y/n): \", cuda_backend),\n        Backend(\"ROCm\", \"USE_ROCM\", \"Use ROCm? (y/n): \", None),\n        Backend(\"Vulkan\", \"USE_VULKAN\", \"Use Vulkan? (y/n): \", None),\n        Backend(\"Metal\", \"USE_METAL\", \"Use Metal (Apple M1/M2 GPU)? (y/n): \", None),\n        opencl_backend,\n        Backend(\n            \"OpenCLHostPtr\",\n            \"USE_OPENCL_ENABLE_HOST_PTR\",\n            \"Use OpenCLHostPtr? (y/n): \",\n            opencl_backend,\n        ),\n    ]\n\n    enabled_backends = set()\n\n    for backend in backends:\n        if backend.parent is not None and backend.parent.name not in enabled_backends:\n            cmake_config_str += f\"set({backend.cmake_config_name} OFF)\\n\"\n        else:\n            while True:\n                use_backend = input(backend.prompt_str)\n                if use_backend in [\"yes\", \"Y\", \"y\"]:\n                    cmake_config_str += f\"set({backend.cmake_config_name} ON)\\n\"\n                    enabled_backends.add(backend.name)\n                    break\n                elif use_backend in [\"no\", \"N\", \"n\"]:\n                    cmake_config_str += f\"set({backend.cmake_config_name} OFF)\\n\"\n                    break\n                else:\n                    print(f\"Invalid input: {use_backend}. Please try again.\")\n\n    if \"CUDA\" in enabled_backends:\n        cmake_config_str += \"set(USE_THRUST ON)\\n\"\n\n    print(\"\\nWriting the following configuration to config.cmake...\")\n    print(cmake_config_str)\n\n    with open(\"config.cmake\", \"w\") as f:\n        f.write(cmake_config_str)\n"
  },
  {
    "path": "cpp/base.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file base.h\n */\n\n#ifndef MLC_LLM_DLL\n#ifdef _WIN32\n#ifdef MLC_LLM_EXPORTS\n#define MLC_LLM_DLL __declspec(dllexport)\n#else\n#define MLC_LLM_DLL __declspec(dllimport)\n#endif\n#else\n#define MLC_LLM_DLL __attribute__((visibility(\"default\")))\n#endif\n#endif\n"
  },
  {
    "path": "cpp/json_ffi/conv_template.cc",
    "content": "#include \"conv_template.h\"\n\n#include <tvm/ffi/function.h>\n\n#include \"../support/json_parser.h\"\n#include \"image_utils.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nusing namespace mlc::llm;\n\n/****************** Model vision config ******************/\n\nModelVisionConfig ModelVisionConfig::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  ModelVisionConfig config;\n\n  Result<int64_t> hidden_size_res = json::LookupWithResultReturn<int64_t>(json_obj, \"hidden_size\");\n  if (hidden_size_res.IsOk()) {\n    config.hidden_size = static_cast<int>(hidden_size_res.Unwrap());\n  }\n\n  Result<int64_t> image_size_res = json::LookupWithResultReturn<int64_t>(json_obj, \"image_size\");\n  if (image_size_res.IsOk()) {\n    config.image_size = static_cast<int>(image_size_res.Unwrap());\n  }\n\n  Result<int64_t> intermediate_size_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"intermediate_size\");\n  if (intermediate_size_res.IsOk()) {\n    config.intermediate_size = static_cast<int>(intermediate_size_res.Unwrap());\n  }\n\n  Result<int64_t> num_attention_heads_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"num_attention_heads\");\n  if (num_attention_heads_res.IsOk()) {\n    config.num_attention_heads = static_cast<int>(num_attention_heads_res.Unwrap());\n  }\n\n  Result<int64_t> num_hidden_layers_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"num_hidden_layers\");\n  if (num_hidden_layers_res.IsOk()) {\n    config.num_hidden_layers = static_cast<int>(num_hidden_layers_res.Unwrap());\n  }\n\n  Result<int64_t> patch_size_res = json::LookupWithResultReturn<int64_t>(json_obj, \"patch_size\");\n  if (patch_size_res.IsOk()) {\n    config.patch_size = static_cast<int>(patch_size_res.Unwrap());\n  }\n\n  Result<int64_t> projection_dim_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"projection_dim\");\n  if (projection_dim_res.IsOk()) {\n    config.projection_dim = 
static_cast<int>(projection_dim_res.Unwrap());\n  }\n\n  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, \"vocab_size\");\n  if (vocab_size_res.IsOk()) {\n    config.vocab_size = static_cast<int>(vocab_size_res.Unwrap());\n  }\n\n  Result<std::string> dtype_res = json::LookupWithResultReturn<std::string>(json_obj, \"dtype\");\n  if (dtype_res.IsOk()) {\n    config.dtype = dtype_res.Unwrap();\n  }\n\n  Result<int64_t> num_channels_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"num_channels\");\n  if (num_channels_res.IsOk()) {\n    config.num_channels = static_cast<int>(num_channels_res.Unwrap());\n  }\n\n  Result<double> layer_norm_eps_res =\n      json::LookupWithResultReturn<double>(json_obj, \"layer_norm_eps\");\n  if (layer_norm_eps_res.IsOk()) {\n    config.layer_norm_eps = layer_norm_eps_res.Unwrap();\n  }\n\n  return config;\n}\n\n/****************** Model config ******************/\n\nModelConfig ModelConfig::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  ModelConfig config;\n\n  Result<int64_t> vocab_size_res = json::LookupWithResultReturn<int64_t>(json_obj, \"vocab_size\");\n  if (vocab_size_res.IsOk()) {\n    config.vocab_size = static_cast<int>(vocab_size_res.Unwrap());\n  }\n\n  Result<int64_t> context_window_size_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"context_window_size\");\n  if (context_window_size_res.IsOk()) {\n    config.context_window_size = static_cast<int>(context_window_size_res.Unwrap());\n  }\n\n  Result<int64_t> sliding_window_size_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"sliding_window_size\");\n  if (sliding_window_size_res.IsOk()) {\n    config.sliding_window_size = static_cast<int>(sliding_window_size_res.Unwrap());\n  }\n\n  Result<int64_t> prefill_chunk_size_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"prefill_chunk_size\");\n  if (prefill_chunk_size_res.IsOk()) {\n    config.prefill_chunk_size = 
static_cast<int>(prefill_chunk_size_res.Unwrap());\n  }\n\n  Result<int64_t> tensor_parallel_shards_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"tensor_parallel_shards\");\n  if (tensor_parallel_shards_res.IsOk()) {\n    config.tensor_parallel_shards = static_cast<int>(tensor_parallel_shards_res.Unwrap());\n  }\n\n  Result<int64_t> pipeline_parallel_stages_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"pipeline_parallel_stages\");\n  if (pipeline_parallel_stages_res.IsOk()) {\n    config.pipeline_parallel_stages = static_cast<int>(pipeline_parallel_stages_res.Unwrap());\n  }\n\n  Result<int64_t> max_batch_size_res =\n      json::LookupWithResultReturn<int64_t>(json_obj, \"max_batch_size\");\n  if (max_batch_size_res.IsOk()) {\n    config.max_batch_size = static_cast<int>(max_batch_size_res.Unwrap());\n  }\n\n  if (json_obj.count(\"vision_config\")) {\n    const tvm::ffi::json::Object& vision_config_obj =\n        json_obj.at(\"vision_config\").cast<tvm::ffi::json::Object>();\n    config.vision_config = ModelVisionConfig::FromJSON(vision_config_obj);\n  }\n\n  return config;\n}\n\n/****************** Conversation template ******************/\n\nstd::unordered_map<MessagePlaceholders, std::string> PLACEHOLDERS = {\n    {MessagePlaceholders::SYSTEM, \"{system_message}\"},\n    {MessagePlaceholders::USER, \"{user_message}\"},\n    {MessagePlaceholders::ASSISTANT, \"{assistant_message}\"},\n    {MessagePlaceholders::TOOL, \"{tool_message}\"},\n    {MessagePlaceholders::FUNCTION, \"{function_string}\"}};\n\nMessagePlaceholders MessagePlaceholderFromString(const std::string& role) {\n  static const std::unordered_map<std::string, MessagePlaceholders> enum_map = {\n      {\"system\", MessagePlaceholders::SYSTEM},       {\"user\", MessagePlaceholders::USER},\n      {\"assistant\", MessagePlaceholders::ASSISTANT}, {\"tool\", MessagePlaceholders::TOOL},\n      {\"function\", MessagePlaceholders::FUNCTION},\n  };\n\n  return 
enum_map.at(role);\n}\n\nConversation::Conversation()\n    : role_templates({{\"user\", PLACEHOLDERS[MessagePlaceholders::USER]},\n                      {\"assistant\", PLACEHOLDERS[MessagePlaceholders::ASSISTANT]},\n                      {\"tool\", PLACEHOLDERS[MessagePlaceholders::TOOL]}}) {}\n\nstd::string Conversation::GetSystemText(const std::string& system_msg) const {\n  std::string system_text = this->system_template;\n  static std::string system_placeholder = PLACEHOLDERS[MessagePlaceholders::SYSTEM];\n  size_t pos = system_text.find(system_placeholder);\n  if (pos != std::string::npos) {\n    system_text.replace(pos, system_placeholder.length(), system_msg);\n  }\n  return system_text;\n}\n\nstd::string Conversation::GetRoleText(const std::string& role, const std::string& content,\n                                      const std::optional<std::string>& fn_call_string) const {\n  std::string role_text = this->role_templates.at(role);\n  std::string placeholder = PLACEHOLDERS[MessagePlaceholderFromString(role)];\n  size_t pos = role_text.find(placeholder);\n  if (pos != std::string::npos) {\n    role_text.replace(pos, placeholder.length(), content);\n  }\n  if (fn_call_string) {\n    // replace placeholder[FUNCTION] with function_string\n    // this assumes function calling is used for a single request scenario only\n    pos = role_text.find(PLACEHOLDERS[MessagePlaceholders::FUNCTION]);\n    if (pos != std::string::npos) {\n      role_text.replace(pos, PLACEHOLDERS[MessagePlaceholders::FUNCTION].length(),\n                        fn_call_string.value());\n    }\n  }\n  return role_text;\n}\n\n/// Try to detect if function calling is needed, if so, return the function calling string\nResult<std::optional<std::string>> TryGetFunctionCallingString(\n    const Conversation& conv, const ChatCompletionRequest& request) {\n  using TResult = Result<std::optional<std::string>>;\n  if (!request.tools.has_value() ||\n      (request.tool_choice.has_value() && 
request.tool_choice.value() == \"none\")) {\n    return TResult::Ok(std::nullopt);\n  }\n  std::vector<ChatTool> tools_ = request.tools.value();\n  // Per the OpenAI API, tool_choice defaults to \"auto\" when tools are provided.\n  std::string tool_choice_ = request.tool_choice.value_or(\"auto\");\n\n  // TODO: support tool_choice given as a dict\n  for (const auto& tool : tools_) {\n    if (tool.function.name == tool_choice_) {\n      tvm::ffi::json::Value function_str(tool.function.AsJSON());\n      return TResult::Ok(tvm::ffi::json::Stringify(function_str));\n    }\n  }\n\n  if (tool_choice_ != \"auto\") {\n    return TResult::Error(\"Invalid tool_choice value in the request: \" + tool_choice_);\n  }\n\n  tvm::ffi::json::Array function_list;\n  for (const auto& tool : tools_) {\n    function_list.push_back(tool.function.AsJSON());\n  }\n\n  tvm::ffi::json::Value function_list_json(function_list);\n  return TResult::Ok(tvm::ffi::json::Stringify(function_list_json));\n}\n\nResult<std::vector<Data>> CreatePrompt(const Conversation& conv,\n                                       const ChatCompletionRequest& request,\n                                       const ModelConfig& config, DLDevice device) {\n  using TResult = Result<std::vector<Data>>;\n\n  Result<std::optional<std::string>> fn_call_str_tmp = TryGetFunctionCallingString(conv, request);\n  if (fn_call_str_tmp.IsErr()) {\n    return TResult::Error(fn_call_str_tmp.UnwrapErr());\n  }\n  std::optional<std::string> fn_call_string = fn_call_str_tmp.Unwrap();\n\n  // Handle system messages\n  bool has_custom_system = false;\n  std::string custom_system_inputs;\n\n  auto f_populate_system_message = [&](const std::vector<ChatCompletionMessage>& msg_vec) {\n    for (const ChatCompletionMessage& msg : msg_vec) {\n      if (msg.role == \"system\") {\n        TVM_FFI_ICHECK(msg.content.IsText()) << \"System message must be text\";\n        custom_system_inputs += msg.content.Text();\n        has_custom_system = true;\n      }\n    }\n  };\n  // Go through the messages in the template and those passed in the request.\n  
f_populate_system_message(conv.messages);\n  f_populate_system_message(request.messages);\n\n  // pending text records the text to be put into data\n  // we lazily accumulate the pending text\n  // to reduce amount of segments in the Data vector\n  std::string pending_text =\n      conv.GetSystemText(has_custom_system ? custom_system_inputs : conv.system_message);\n\n  // Get the message strings\n  std::vector<Data> message_list;\n  size_t non_system_msg_count = 0;\n\n  // returns error if error happens\n  auto f_process_messages =\n      [&](const std::vector<ChatCompletionMessage>& msg_vec) -> std::optional<TResult> {\n    for (size_t i = 0; i < msg_vec.size(); ++i) {\n      const ChatCompletionMessage& msg = msg_vec[i];\n      // skip system message as it is already processed\n      if (msg.role == \"system\") continue;\n\n      auto role_it = conv.roles.find(msg.role);\n      if (role_it == conv.roles.end()) {\n        return TResult::Error(\"Role \\\"\" + msg.role + \"\\\" is not supported\");\n      }\n      const std::string& role_name = role_it->second;\n      // skip when content is empty\n      if (msg.content.IsNull()) {\n        pending_text += role_name + conv.role_empty_sep;\n        continue;\n      }\n      ++non_system_msg_count;\n      // assistant uses conv.seps[1] if there are two seps\n      int sep_offset = msg.role == \"assistant\" ? 
1 : 0;\n      const std::string& seperator = conv.seps[sep_offset % conv.seps.size()];\n      // setup role prefix\n      std::string role_prefix = \"\";\n      // Do not append role prefix if this is the first message and there is already a system\n      // message\n      if (conv.add_role_after_system_message || pending_text.empty() || non_system_msg_count != 1) {\n        role_prefix = role_name + conv.role_content_sep;\n      }\n      pending_text += role_prefix;\n\n      if (msg.content.IsParts()) {\n        for (const auto& item : msg.content.Parts()) {\n          auto it_type = item.find(\"type\");\n          if (it_type == item.end()) {\n            return TResult::Error(\"The content of a message does not have \\\"type\\\" field\");\n          }\n          if (it_type->second == \"text\") {\n            auto it_text = item.find(\"text\");\n            if (it_text == item.end()) {\n              return TResult::Error(\n                  \"The text type content of a message does not have \\\"text\\\" field\");\n            }\n            // replace placeholder[ROLE] with input message from role\n            pending_text += conv.GetRoleText(msg.role, it_text->second, fn_call_string);\n          } else if (it_type->second == \"image_url\") {\n            if (item.find(\"image_url\") == item.end()) {\n              return TResult::Error(\"Content should have an image_url field\");\n            }\n            std::string image_url =\n                item.at(\"image_url\");  // TODO(mlc-team): According to OpenAI API reference this\n                                       // should be a map, with a \"url\" key containing the URL, but\n                                       // we are just assuming this as the URL for now\n            std::string base64_image = image_url.substr(image_url.find(\",\") + 1);\n            Result<Tensor> image_data_res = LoadImageFromBase64(base64_image);\n            if (image_data_res.IsErr()) {\n              return 
TResult::Error(image_data_res.UnwrapErr());\n            }\n            if (!config.vision_config.has_value()) {\n              return TResult::Error(\"Vision config is required for image input\");\n            }\n            int image_size = config.vision_config.value().image_size;\n            int patch_size = config.vision_config.value().patch_size;\n\n            int embed_size = (image_size * image_size) / (patch_size * patch_size);\n\n            Tensor image_data = image_data_res.Unwrap();\n            std::vector<int64_t> new_shape = {1, image_size, image_size, 3};\n            Tensor image_tensor = image_data.CreateView(new_shape, image_data.DataType());\n            // TODO: Not sure if commenting will affect other functions. But\n            // python part will do clip preprocessing. auto image_tensor =\n            // ClipPreprocessor(image_data_res.Unwrap(), image_size, device);\n            // lazily commit text data\n            if (pending_text.length() != 0) {\n              message_list.push_back(TextData(pending_text));\n              pending_text = \"\";\n            }\n            message_list.push_back(ImageData(image_tensor, embed_size));\n          } else {\n            return TResult::Error(\"Unsupported content type: \" + it_type->second);\n          }\n        }\n      } else {\n        TVM_FFI_ICHECK(msg.content.IsText());\n        pending_text += conv.GetRoleText(msg.role, msg.content.Text(), fn_call_string);\n      }\n      pending_text += seperator;\n    }\n    return std::nullopt;\n  };\n\n  if (auto err = f_process_messages(conv.messages)) {\n    return err.value();\n  }\n  if (auto err = f_process_messages(request.messages)) {\n    return err.value();\n  }\n  // append last assistant begin message\n  ChatCompletionMessage last_assistant_begin;\n  last_assistant_begin.role = \"assistant\";\n  last_assistant_begin.content = std::nullopt;\n  if (auto err = f_process_messages({last_assistant_begin})) {\n    return err.value();\n  }\n  
if (pending_text.length() != 0) {\n    message_list.push_back(TextData(pending_text));\n  }\n  // Handle system_prefix_token_ids\n  if (conv.system_prefix_token_ids.has_value()) {\n    message_list.insert(message_list.begin(), TokenData(conv.system_prefix_token_ids.value()));\n  }\n  return TResult::Ok(message_list);\n}\n\nResult<Conversation> Conversation::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  using TResult = Result<Conversation>;\n  Conversation conv;\n\n  Result<std::optional<std::string>> name_res =\n      json::LookupOptionalWithResultReturn<std::string>(json_obj, \"name\");\n  if (name_res.IsErr()) {\n    return TResult::Error(name_res.UnwrapErr());\n  }\n  conv.name = name_res.Unwrap();\n\n  Result<std::string> system_template_res =\n      json::LookupWithResultReturn<std::string>(json_obj, \"system_template\");\n  if (system_template_res.IsErr()) {\n    return TResult::Error(system_template_res.UnwrapErr());\n  }\n  conv.system_template = system_template_res.Unwrap();\n\n  Result<std::string> system_message_res =\n      json::LookupWithResultReturn<std::string>(json_obj, \"system_message\");\n  if (system_message_res.IsErr()) {\n    return TResult::Error(system_message_res.UnwrapErr());\n  }\n  conv.system_message = system_message_res.Unwrap();\n\n  Result<std::optional<tvm::ffi::json::Array>> system_prefix_token_ids_arr_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Array>(json_obj,\n                                                                  \"system_prefix_token_ids\");\n  if (system_prefix_token_ids_arr_res.IsErr()) {\n    return TResult::Error(system_prefix_token_ids_arr_res.UnwrapErr());\n  }\n  std::optional<tvm::ffi::json::Array> system_prefix_token_ids_arr =\n      system_prefix_token_ids_arr_res.Unwrap();\n  if (system_prefix_token_ids_arr.has_value()) {\n    std::vector<int> system_prefix_token_ids;\n    system_prefix_token_ids.reserve(system_prefix_token_ids_arr.value().size());\n    for (const auto& 
token_id : system_prefix_token_ids_arr.value()) {\n      if (!token_id.try_cast<int64_t>().has_value()) {\n        return TResult::Error(\"A system prefix token id is not integer.\");\n      }\n      system_prefix_token_ids.push_back(static_cast<int>(token_id.cast<int64_t>()));\n    }\n    conv.system_prefix_token_ids = std::move(system_prefix_token_ids);\n  }\n\n  Result<bool> add_role_after_system_message_res =\n      json::LookupWithResultReturn<bool>(json_obj, \"add_role_after_system_message\");\n  if (add_role_after_system_message_res.IsErr()) {\n    return TResult::Error(add_role_after_system_message_res.UnwrapErr());\n  }\n  conv.add_role_after_system_message = add_role_after_system_message_res.Unwrap();\n\n  Result<tvm::ffi::json::Object> roles_object_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Object>(json_obj, \"roles\");\n  if (roles_object_res.IsErr()) {\n    return TResult::Error(roles_object_res.UnwrapErr());\n  }\n  for (const auto& role : roles_object_res.Unwrap()) {\n    if (!role.second.try_cast<std::string>().has_value()) {\n      return TResult::Error(\"A role value in the conversation template is not a string.\");\n    }\n    conv.roles[role.first.cast<tvm::ffi::String>()] = role.second.cast<std::string>();\n  }\n\n  Result<std::optional<tvm::ffi::json::Object>> role_templates_object_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Object>(json_obj, \"role_templates\");\n  if (role_templates_object_res.IsErr()) {\n    return TResult::Error(role_templates_object_res.UnwrapErr());\n  }\n  std::optional<tvm::ffi::json::Object> role_templates_object = role_templates_object_res.Unwrap();\n  if (role_templates_object.has_value()) {\n    for (const auto& [role, msg] : role_templates_object.value()) {\n      if (!msg.try_cast<std::string>().has_value()) {\n        return TResult::Error(\"A value in \\\"role_templates\\\" is not a string.\");\n      }\n      conv.role_templates[role.cast<tvm::ffi::String>()] = 
msg.cast<std::string>();\n    }\n  }\n\n  Result<tvm::ffi::json::Array> messages_arr_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Array>(json_obj, \"messages\");\n  if (messages_arr_res.IsErr()) {\n    return TResult::Error(messages_arr_res.UnwrapErr());\n  }\n  for (const auto& message : messages_arr_res.Unwrap()) {\n    if (!message.try_cast<tvm::ffi::json::Array>().has_value() ||\n        message.cast<tvm::ffi::json::Array>().size() != 2) {\n      return TResult::Error(\n          \"A message in the conversation template is not an array of [role, content].\");\n    }\n    tvm::ffi::json::Array message_arr = message.cast<tvm::ffi::json::Array>();\n    if (!message_arr[0].try_cast<std::string>().has_value()) {\n      return TResult::Error(\"The role of a message in the conversation template is not a string.\");\n    }\n    std::string role = message_arr[0].cast<std::string>();\n    // content can be a string or an array of objects\n    if (message_arr[1].try_cast<std::string>().has_value()) {\n      ChatCompletionMessage msg;\n      msg.role = role;\n      msg.content = message_arr[1].cast<std::string>();\n      conv.messages.push_back(msg);\n      continue;\n    } else if (message_arr[1].try_cast<tvm::ffi::json::Array>().has_value()) {\n      tvm::ffi::json::Array content_arr = message_arr[1].cast<tvm::ffi::json::Array>();\n      std::vector<std::unordered_map<std::string, std::string>> content;\n      content.reserve(content_arr.size());\n      for (const auto& item : content_arr) {\n        if (!item.try_cast<tvm::ffi::json::Object>().has_value()) {\n          return TResult::Error(\"The content of conversation template message is not an object\");\n        }\n        std::unordered_map<std::string, std::string> item_map;\n        for (const auto& [key, value] : item.cast<tvm::ffi::json::Object>()) {\n          item_map[key.cast<tvm::ffi::String>()] = tvm::ffi::json::Stringify(value);\n        }\n        content.push_back(std::move(item_map));\n    
  }\n      ChatCompletionMessage msg;\n      msg.role = role;\n      msg.content = content;\n      conv.messages.push_back(msg);\n      continue;\n    } else {\n      return TResult::Error(\n          \"The content of a message in the conversation template is not a string or an array.\");\n    }\n  }\n\n  Result<tvm::ffi::json::Array> seps_arr_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Array>(json_obj, \"seps\");\n  if (seps_arr_res.IsErr()) {\n    return TResult::Error(seps_arr_res.UnwrapErr());\n  }\n  std::vector<std::string> seps;\n  for (const auto& sep : seps_arr_res.Unwrap()) {\n    if (!sep.try_cast<std::string>().has_value()) {\n      return TResult::Error(\"A separator (\\\"seps\\\") of the conversation template is not a string\");\n    }\n    conv.seps.push_back(sep.cast<std::string>());\n  }\n\n  Result<std::string> role_content_sep_res =\n      json::LookupWithResultReturn<std::string>(json_obj, \"role_content_sep\");\n  if (role_content_sep_res.IsErr()) {\n    return TResult::Error(role_content_sep_res.UnwrapErr());\n  }\n  conv.role_content_sep = role_content_sep_res.Unwrap();\n\n  Result<std::string> role_empty_sep_res =\n      json::LookupWithResultReturn<std::string>(json_obj, \"role_empty_sep\");\n  if (role_empty_sep_res.IsErr()) {\n    return TResult::Error(role_empty_sep_res.UnwrapErr());\n  }\n  conv.role_empty_sep = role_empty_sep_res.Unwrap();\n\n  Result<tvm::ffi::json::Array> stop_str_arr_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Array>(json_obj, \"stop_str\");\n  if (stop_str_arr_res.IsErr()) {\n    return TResult::Error(stop_str_arr_res.UnwrapErr());\n  }\n  for (const auto& stop : stop_str_arr_res.Unwrap()) {\n    if (!stop.try_cast<std::string>().has_value()) {\n      return TResult::Error(\n          \"A stop string (\\\"stop_str\\\") of the conversation template is not a string.\");\n    }\n    conv.stop_str.push_back(stop.cast<std::string>());\n  }\n\n  Result<tvm::ffi::json::Array> 
stop_token_ids_arr_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Array>(json_obj, \"stop_token_ids\");\n  if (stop_token_ids_arr_res.IsErr()) {\n    return TResult::Error(stop_token_ids_arr_res.UnwrapErr());\n  }\n  for (const auto& stop : stop_token_ids_arr_res.Unwrap()) {\n    if (!stop.try_cast<int64_t>().has_value()) {\n      return TResult::Error(\n          \"A stop token id (\\\"stop_token_ids\\\") of the conversation template is not an integer.\");\n    }\n    conv.stop_token_ids.push_back(static_cast<int>(stop.cast<int64_t>()));\n  }\n  return TResult::Ok(conv);\n}\n\nResult<Conversation> Conversation::FromJSON(const std::string& json_str) {\n  Result<tvm::ffi::json::Object> json_obj = json::ParseToJSONObjectWithResultReturn(json_str);\n  if (json_obj.IsErr()) {\n    return Result<Conversation>::Error(json_obj.UnwrapErr());\n  }\n  return Conversation::FromJSON(json_obj.Unwrap());\n}\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/json_ffi/conv_template.h",
    "content": "#ifndef MLC_LLM_JSON_FFI_CONV_TEMPLATE_H\n#define MLC_LLM_JSON_FFI_CONV_TEMPLATE_H\n\n#include <tvm/ffi/extra/json.h>\n\n#include <iostream>\n#include <map>\n#include <optional>\n#include <string>\n#include <typeinfo>\n#include <variant>\n#include <vector>\n\n#include \"../serve/data.h\"\n#include \"../support/result.h\"\n#include \"openai_api_protocol.h\"\n\nusing namespace mlc::llm::serve;\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\n/****************** Model vision config ******************/\n\n/*! \\brief Defines the Vision config of the model (if present) */\nclass ModelVisionConfig {\n public:\n  int hidden_size;\n  int image_size;\n  int intermediate_size;\n  int num_attention_heads;\n  int num_hidden_layers;\n  int patch_size;\n  int projection_dim;\n  int vocab_size;\n  std::string dtype;\n  int num_channels;\n  double layer_norm_eps;\n\n  static ModelVisionConfig FromJSON(const tvm::ffi::json::Object& json_obj);\n};\n\n/****************** Model config ******************/\n\n/*! 
\brief Defines the config of the model.\nPopulated from the \"model_config\" field in mlc-chat-config.json */\nclass ModelConfig {\n public:\n  int vocab_size;\n  int context_window_size;\n  int sliding_window_size;\n  int prefill_chunk_size;\n  int tensor_parallel_shards;\n  int pipeline_parallel_stages;\n  int max_batch_size;\n  std::optional<ModelVisionConfig> vision_config = std::nullopt;\n\n  static ModelConfig FromJSON(const tvm::ffi::json::Object& json_obj);\n};\n\n/****************** Conversation template ******************/\n\nenum class MessagePlaceholders { SYSTEM, USER, ASSISTANT, TOOL, FUNCTION };\n\nMessagePlaceholders MessagePlaceholderFromString(const std::string& role);\n\n/**\n * @brief A struct that specifies the conversation template\n * and contains the conversation history.\n */\nstruct Conversation {\n  // Optional name of the template.\n  std::optional<std::string> name = std::nullopt;\n\n  // The system prompt template. It optionally contains the system\n  // message placeholder, which is replaced with the system message below.\n  std::string system_template;\n\n  // The content of the system prompt (without the template format).\n  std::string system_message;\n\n  // The system token ids to be prepended to the beginning of the tokenized\n  // generated prompt.\n  std::optional<std::vector<int>> system_prefix_token_ids = std::nullopt;\n\n  // Whether to append the user role and separator after the system message.\n  // This is mainly for [INST] [/INST] style prompt formats.\n  bool add_role_after_system_message = true;\n\n  // The conversation roles.\n  std::unordered_map<std::string, std::string> roles;\n\n  // The role prompt templates. Each optionally contains the default\n  // message placeholders, which are replaced by the actual content.\n  std::unordered_map<std::string, std::string> role_templates;\n\n  // The conversation history messages.\n  // Each message carries a role and its content; the content can be null.\n  std::vector<ChatCompletionMessage> messages;\n\n  // The separators between messages when concatenating into a single prompt.\n  // List size should be either 1 or 2.\n  // - When size is 1, the separator is used between adjacent messages.\n  // - When size is 2, seps[0] is used after the user message, and\n  //   seps[1] is used after the assistant message.\n  std::vector<std::string> seps;\n\n  // The separator between the role and the content in a message.\n  std::string role_content_sep;\n\n  // The separator between the role and an empty content.\n  std::string role_empty_sep;\n\n  // The stop criteria: stop strings and stop token ids.\n  std::vector<std::string> stop_str;\n  std::vector<int> stop_token_ids;\n\n  Conversation();\n\n  /*!\n   * \\brief Get the system text (with the prompt template applied) given the system prompt message.\n   * \\param system_msg The system prompt message.\n   * \\return The created system text.\n   */\n  std::string GetSystemText(const std::string& system_msg) const;\n\n  /*!\n   * \\brief Replace the message placeholder in the role's template with the given content.\n   * \\param role The input role.\n   * \\param content The input content from the role.\n   * \\param fn_call_str The function-calling string, if any.\n   * \\return The created text.\n   */\n  std::string GetRoleText(const std::string& role, const std::string& content,\n                          const std::optional<std::string>& fn_call_str) const;\n\n  /*! \\brief Create a Conversation instance from the given JSON object. */\n  static Result<Conversation> FromJSON(const tvm::ffi::json::Object& json);\n  /*! \\brief Parse and create a Conversation instance from the given JSON string. */\n  static Result<Conversation> FromJSON(const std::string& json_str);\n};\n\n/*! \\brief Create the list of prompts from the messages based on the conversation template. 
*/\nResult<std::vector<Data>> CreatePrompt(const Conversation& conv,\n                                       const ChatCompletionRequest& request,\n                                       const ModelConfig& config, DLDevice device);\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_JSON_FFI_CONV_TEMPLATE_H\n"
  },
  {
    "path": "cpp/json_ffi/image_utils.cc",
    "content": "#include \"image_utils.h\"\n\n#include <tvm/support/io.h>\n\n#include <cstring>\n\n#include \"../../3rdparty/tvm/src/support/base64.h\"\n#define STB_IMAGE_IMPLEMENTATION\n#include \"stb_image.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nusing namespace tvm::runtime;\n\nclass MemoryBufferStream : public tvm::support::Stream {\n public:\n  using Stream::Read;\n  using Stream::Write;\n\n  MemoryBufferStream(const char* data, size_t size) : data_(data), size_(size), pos_(0) {}\n\n  size_t Read(void* ptr, size_t size) override {\n    size_t remaining = size_ - pos_;\n    if (size > remaining) {\n      size = remaining;\n    }\n    if (size == 0) {\n      return 0;\n    }\n    std::memcpy(ptr, data_ + pos_, size);\n    pos_ += size;\n    return size;\n  }\n\n  size_t Write(const void* ptr, size_t size) override {\n    TVM_FFI_THROW(InternalError) << \"MemoryBufferStream does not support write\";\n    return 0;\n  }\n\n private:\n  const char* data_;\n  size_t size_;\n  size_t pos_;\n};\n\n// Decoded byte count of a base64 string: 3 bytes per 4 characters, minus '=' padding.\nsize_t Base64DecodedSize(const std::string& base64_str) {\n  size_t len = base64_str.size();\n  // Guard against out-of-bounds access on inputs too short to be valid base64.\n  if (len < 2) {\n    return 0;\n  }\n  size_t padding = 0;\n  if (base64_str[len - 1] == '=') {\n    padding++;\n  }\n  if (base64_str[len - 2] == '=') {\n    padding++;\n  }\n  return 3 * len / 4 - padding;\n}\n\nResult<Tensor> LoadImageFromBase64(const std::string& base64_str) {\n  using TResult = Result<Tensor>;\n  MemoryBufferStream stream(base64_str.c_str(), base64_str.size());\n  tvm::support::Base64InStream base64_stream(&stream);\n  size_t decoded_size = Base64DecodedSize(base64_str);\n  std::vector<unsigned char> decoded(decoded_size);\n  base64_stream.InitPosition();\n  base64_stream.Read((void*)decoded.data(), decoded_size);\n  int width, height, num_channels;\n  // Request 3 output channels so the decoded image is always RGB.\n  unsigned char* image_data =\n      stbi_load_from_memory(decoded.data(), decoded_size, &width, &height, &num_channels, 3);\n  if (!image_data) {\n    return TResult::Error(stbi_failure_reason());\n  }\n  auto image_tensor = Tensor::Empty({height, 
width, 3}, {kDLUInt, 8, 1}, {kDLCPU, 0});\n  image_tensor.CopyFromBytes((void*)image_data, width * height * 3);\n  stbi_image_free(image_data);\n  return TResult::Ok(image_tensor);\n}\n\nTensor ClipPreprocessor(Tensor image_data, int target_size, DLDevice device) {\n  int height = image_data->shape[0];\n  int width = image_data->shape[1];\n  // Resize\n  const int short_side = width < height ? width : height;\n  const int long_side = width > height ? width : height;\n  const int new_short_side = target_size;\n  const int new_long_side = (int)(new_short_side * (long_side / (float)short_side));\n  const int new_width = width < height ? new_short_side : new_long_side;\n  const int new_height = width > height ? new_short_side : new_long_side;\n\n  std::vector<float> processed_image_data(new_width * new_height * 3);\n\n  // Bilinear Interpolation\n  for (int y = 0; y < new_height; y++) {\n    for (int x = 0; x < new_width; x++) {\n      const float x_ratio = float(width - 1) / new_width;\n      const float y_ratio = float(height - 1) / new_height;\n      const int x1 = int(x_ratio * x);\n      const int y1 = int(y_ratio * y);\n      const int x2 = x1 + 1;\n      const int y2 = y1 + 1;\n      const float x_diff = x_ratio * x - x1;\n      const float y_diff = y_ratio * y - y1;\n      for (int c = 0; c < 3; c++) {\n        const uint8_t top_left = ((uint8_t*)image_data->data)[(y1 * width + x1) * 3 + c];\n        const uint8_t top_right = ((uint8_t*)image_data->data)[(y1 * width + x2) * 3 + c];\n        const uint8_t bottom_left = ((uint8_t*)image_data->data)[(y2 * width + x1) * 3 + c];\n        const uint8_t bottom_right = ((uint8_t*)image_data->data)[(y2 * width + x2) * 3 + c];\n        processed_image_data[(y * new_width + x) * 3 + c] =\n            (float)(int(top_left * (1 - x_diff) * (1 - y_diff) + top_right * x_diff * (1 - y_diff) +\n                        bottom_left * y_diff * (1 - x_diff) + bottom_right * x_diff * y_diff));\n      }\n    }\n  }\n\n  // Center 
crop\n  const int crop_x = (new_width - target_size) / 2;\n  const int crop_y = (new_height - target_size) / 2;\n  std::vector<float> cropped_image_data(target_size * target_size * 3);\n  for (int y = 0; y < target_size; y++) {\n    for (int x = 0; x < target_size; x++) {\n      for (int c = 0; c < 3; c++) {\n        cropped_image_data[(y * target_size + x) * 3 + c] =\n            processed_image_data[((y + crop_y) * new_width + x + crop_x) * 3 + c];\n      }\n    }\n  }\n\n  // Rescale\n  for (int i = 0; i < target_size * target_size * 3; i++) {\n    cropped_image_data[i] = cropped_image_data[i] / 255.0f;\n  }\n\n  // Normalize\n  const float IMAGE_MEAN[] = {0.48145466f, 0.4578275f, 0.40821073f};\n  const float IMAGE_STD[] = {0.26862954f, 0.26130258f, 0.27577711f};\n  for (int i = 0; i < target_size * target_size * 3; i++) {\n    const int c = i % 3;\n    cropped_image_data[i] = (cropped_image_data[i] - IMAGE_MEAN[c]) / IMAGE_STD[c];\n  }\n\n  std::vector<float> image_data_channel_first(target_size * target_size * 3);\n  for (int y = 0; y < target_size; y++) {\n    for (int x = 0; x < target_size; x++) {\n      for (int c = 0; c < 3; c++) {\n        image_data_channel_first[c * target_size * target_size + y * target_size + x] =\n            cropped_image_data[(y * target_size + x) * 3 + c];\n      }\n    }\n  }\n\n  // Create Tensor\n  auto image_tensor = Tensor::Empty({1, 3, target_size, target_size}, {kDLFloat, 32, 1}, device);\n  image_tensor.CopyFromBytes((void*)image_data_channel_first.data(),\n                             target_size * target_size * 3 * sizeof(float));\n\n  return image_tensor;\n}\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/json_ffi/image_utils.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file json_ffi/image_utils.h\n * \\brief The header of image utilities for the JSON FFI engine in MLC LLM.\n */\n#ifndef MLC_LLM_JSON_FFI_IMAGE_UTILS_H_\n#define MLC_LLM_JSON_FFI_IMAGE_UTILS_H_\n\n#include <tvm/runtime/tensor.h>\n\n#include <optional>\n#include <string>\n\n#include \"../support/result.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\n/*! \\brief Load a base64-encoded image string into a uint8 RGB CPU Tensor of shape {height, width, 3} */\nResult<tvm::runtime::Tensor> LoadImageFromBase64(const std::string& base64_str);\n\n/*!\n * \\brief Preprocess the CPU image for the CLIP encoder (resize, center crop, rescale, normalize)\n * and return a float32 Tensor of shape {1, 3, target_size, target_size} on the given device.\n */\ntvm::runtime::Tensor ClipPreprocessor(tvm::runtime::Tensor image_data, int target_size,\n                                      DLDevice device);\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_JSON_FFI_IMAGE_UTILS_H_\n"
  },
  {
    "path": "cpp/json_ffi/json_ffi_engine.cc",
    "content": "#include \"json_ffi_engine.h\"\n\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/module.h>\n\n#include <filesystem>\n#include <fstream>\n\n#include \"../serve/model.h\"\n#include \"../support/json_parser.h\"\n#include \"../support/result.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nusing namespace tvm::runtime;\n\nJSONFFIEngine::JSONFFIEngine() { engine_ = serve::ThreadedEngine::Create(); }\n\nbool JSONFFIEngine::ChatCompletion(std::string request_json_str, std::string request_id) {\n  bool success = this->AddRequest(request_json_str, request_id);\n  if (!success) {\n    this->StreamBackError(request_id);\n  }\n  return success;\n}\n\nvoid JSONFFIEngine::StreamBackError(std::string request_id) {\n  ChatCompletionMessage delta;\n  delta.content = this->err_;\n  delta.role = \"assistant\";\n\n  ChatCompletionStreamResponseChoice choice;\n  choice.finish_reason = FinishReason::error;\n  choice.index = 0;\n  choice.delta = delta;\n\n  ChatCompletionStreamResponse response;\n  response.id = request_id;\n  response.choices = std::vector<ChatCompletionStreamResponseChoice>{choice};\n  response.model = \"json_ffi\";  // TODO: Return model name from engine (or from args)\n  response.system_fingerprint = \"\";\n\n  tvm::ffi::json::Array response_arr;\n  response_arr.push_back(response.AsJSON());\n\n  // now stream back the final usage block, which is required.\n  // NOTE: always stream back final usage block as it is an\n  // invariant of the system\n  response.choices.clear();\n  tvm::ffi::json::Object dummy_usage;\n  dummy_usage.Set(\"prompt_tokens\", static_cast<int64_t>(0));\n  dummy_usage.Set(\"completion_tokens\", static_cast<int64_t>(0));\n  dummy_usage.Set(\"total_tokens\", static_cast<int64_t>(0));\n  response.usage = tvm::ffi::json::Value(dummy_usage);\n  response_arr.push_back(response.AsJSON());\n\n  std::string stream_back_json = 
tvm::ffi::json::Stringify(response_arr);\n  this->request_stream_callback_(stream_back_json);\n}\n\nbool JSONFFIEngine::AddRequest(std::string request_json_str, std::string request_id) {\n  Result<ChatCompletionRequest> request_res = ChatCompletionRequest::FromJSON(request_json_str);\n  if (request_res.IsErr()) {\n    err_ = request_res.UnwrapErr();\n    return false;\n  }\n  ChatCompletionRequest request = request_res.Unwrap();\n  Array<Data> inputs;\n  Array<String> stop_strs;\n  bool is_special_request =\n      (request.debug_config.has_value() &&\n       request.debug_config.value().special_request != SpecialRequestKind::kNone);\n  // special request does not have to go through prompt construction\n  if (!is_special_request) {\n    // get prompt: note, assistant was appended in the end.\n    Result<std::vector<Data>> inputs_obj =\n        CreatePrompt(this->conv_template_, request, this->model_config_, this->device_);\n    if (inputs_obj.IsErr()) {\n      err_ = inputs_obj.UnwrapErr();\n      return false;\n    }\n    inputs = inputs_obj.Unwrap();\n\n    stop_strs.reserve(this->conv_template_.stop_str.size());\n    for (const std::string& stop_str : this->conv_template_.stop_str) {\n      stop_strs.push_back(stop_str);\n    }\n    if (request.stop.has_value()) {\n      stop_strs.reserve(stop_strs.size() + request.stop.value().size());\n      for (const std::string& stop_str : request.stop.value()) {\n        stop_strs.push_back(stop_str);\n      }\n    }\n  }\n  // create a generation config from request\n  const auto& default_gen_cfg = default_generation_config_;\n  auto gen_cfg = tvm::ffi::make_object<GenerationConfigNode>();\n  gen_cfg->n = request.n;\n  gen_cfg->temperature = request.temperature.value_or(default_gen_cfg->temperature);\n  gen_cfg->top_p = request.top_p.value_or(default_gen_cfg->top_p);\n  gen_cfg->frequency_penalty =\n      request.frequency_penalty.value_or(default_gen_cfg->frequency_penalty);\n  gen_cfg->presence_penalty = 
request.presence_penalty.value_or(default_gen_cfg->presence_penalty);\n  gen_cfg->logprobs = request.logprobs;\n  gen_cfg->top_logprobs = request.top_logprobs;\n  gen_cfg->logit_bias = request.logit_bias.value_or(default_gen_cfg->logit_bias);\n  gen_cfg->seed = request.seed.value_or(std::random_device{}());\n  gen_cfg->max_tokens = request.max_tokens.value_or(default_gen_cfg->max_tokens);\n  gen_cfg->stop_strs = std::move(stop_strs);\n  gen_cfg->stop_token_ids = conv_template_.stop_token_ids;\n  gen_cfg->response_format = request.response_format.value_or(ResponseFormat());\n  gen_cfg->debug_config = request.debug_config.value_or(DebugConfig());\n\n  Result<GenerationConfig> res_gen_config = GenerationConfig::Validate(GenerationConfig(gen_cfg));\n  if (res_gen_config.IsErr()) {\n    err_ = res_gen_config.UnwrapErr();\n    return false;\n  }\n\n  Request engine_request(request_id, inputs, res_gen_config.Unwrap());\n\n  // setup request state\n  RequestState rstate;\n  rstate.model = request.model.value_or(\"\");\n  rstate.streamer.reserve(gen_cfg->n);\n  for (int i = 0; i < gen_cfg->n; ++i) {\n    rstate.streamer.push_back(TextStreamer(tokenizer_));\n  }\n  request_map_[request_id] = std::move(rstate);\n\n  this->engine_->AddRequest(engine_request);\n  return true;\n}\n\nbool JSONFFIEngine::Abort(std::string request_id) {\n  this->engine_->AbortRequest(request_id);\n  auto it = request_map_.find(request_id);\n  if (it != request_map_.end()) {\n    request_map_.erase(it);\n  }\n  return true;\n}\n\nstd::string JSONFFIEngine::GetLastError() { return err_; }\n\nvoid JSONFFIEngine::ExitBackgroundLoop() { this->engine_->ExitBackgroundLoop(); }\n\nJSONFFIEngine::~JSONFFIEngine() { this->ExitBackgroundLoop(); }\n\nclass JSONFFIEngineImpl : public JSONFFIEngine, public ffi::ModuleObj {\n public:\n  TVM_MODULE_VTABLE_BEGIN(\"mlc.json_ffi\");\n  TVM_MODULE_VTABLE_ENTRY(\"init_background_engine\", &JSONFFIEngineImpl::InitBackgroundEngine);\n  TVM_MODULE_VTABLE_ENTRY(\"reload\", 
&JSONFFIEngineImpl::Reload);\n  TVM_MODULE_VTABLE_ENTRY(\"unload\", &JSONFFIEngineImpl::Unload);\n  TVM_MODULE_VTABLE_ENTRY(\"reset\", &JSONFFIEngineImpl::Reset);\n  TVM_MODULE_VTABLE_ENTRY(\"chat_completion\", &JSONFFIEngineImpl::ChatCompletion);\n  TVM_MODULE_VTABLE_ENTRY(\"abort\", &JSONFFIEngineImpl::Abort);\n  TVM_MODULE_VTABLE_ENTRY(\"get_last_error\", &JSONFFIEngineImpl::GetLastError);\n  TVM_MODULE_VTABLE_ENTRY(\"run_background_loop\", &JSONFFIEngineImpl::RunBackgroundLoop);\n  TVM_MODULE_VTABLE_ENTRY(\"run_background_stream_back_loop\",\n                          &JSONFFIEngineImpl::RunBackgroundStreamBackLoop);\n  TVM_MODULE_VTABLE_ENTRY(\"exit_background_loop\", &JSONFFIEngineImpl::ExitBackgroundLoop);\n  TVM_MODULE_VTABLE_END();\n\n  void InitBackgroundEngine(int device_type, int device_id,\n                            Optional<Function> request_stream_callback) {\n    DLDevice device{static_cast<DLDeviceType>(device_type), device_id};\n    this->device_ = device;\n    TVM_FFI_ICHECK(request_stream_callback.defined())\n        << \"JSONFFIEngine requires request stream callback function, but it is not given.\";\n    this->request_stream_callback_ = request_stream_callback.value();\n\n    auto frequest_stream_callback_wrapper = [this](ffi::PackedArgs args, ffi::Any* ret) {\n      TVM_FFI_ICHECK_EQ(args.size(), 1);\n      Array<RequestStreamOutput> delta_outputs = args[0].cast<Array<RequestStreamOutput>>();\n      std::string responses = this->GetResponseFromStreamOutput(delta_outputs);\n      this->request_stream_callback_(responses);\n    };\n\n    request_stream_callback = Function(frequest_stream_callback_wrapper);\n    this->engine_->InitThreadedEngine(device, std::move(request_stream_callback), std::nullopt);\n  }\n\n  void Reload(String engine_config_json_str) {\n    this->engine_->Reload(engine_config_json_str);\n    this->default_generation_config_ = this->engine_->GetDefaultGenerationConfig();\n    auto engine_config = 
this->engine_->GetCompleteEngineConfig();\n\n    // Load conversation template.\n    Result<tvm::ffi::json::Object> model_config_json =\n        serve::Model::LoadModelConfig(engine_config->model);\n    TVM_FFI_ICHECK(model_config_json.IsOk()) << model_config_json.UnwrapErr();\n    const tvm::ffi::json::Object& model_config_json_unwrapped = model_config_json.Unwrap();\n    Result<Conversation> conv_template = Conversation::FromJSON(\n        json::Lookup<tvm::ffi::json::Object>(model_config_json_unwrapped, \"conv_template\"));\n    TVM_FFI_ICHECK(!conv_template.IsErr())\n        << \"Invalid conversation template JSON: \" << conv_template.UnwrapErr();\n    this->conv_template_ = conv_template.Unwrap();\n    this->model_config_ = ModelConfig::FromJSON(\n        json::Lookup<tvm::ffi::json::Object>(model_config_json_unwrapped, \"model_config\"));\n    this->tokenizer_ = Tokenizer::FromPath(engine_config->model);\n  }\n\n  void Unload() { this->engine_->Unload(); }\n\n  void Reset() { this->engine_->Reset(); }\n\n  void RunBackgroundLoop() { this->engine_->RunBackgroundLoop(); }\n\n  void RunBackgroundStreamBackLoop() { this->engine_->RunBackgroundStreamBackLoop(); }\n\n  String GetResponseFromStreamOutput(Array<RequestStreamOutput> delta_outputs) {\n    tvm::ffi::json::Array json_response_arr;\n    for (const auto& delta_output : delta_outputs) {\n      std::string request_id = delta_output->request_id;\n      auto request_state_it = request_map_.find(request_id);\n      if (request_state_it == request_map_.end()) continue;\n      RequestState& rstate = request_state_it->second;\n\n      // build the final usage messages\n      // invariant, we can always let other messages to come first\n      // then the final usage messages, as final usage is always last\n      if (delta_output->request_final_usage_json_str.has_value()) {\n        ChatCompletionStreamResponse response;\n        response.id = request_id;\n        response.model = rstate.model;\n        
response.system_fingerprint = \"\";\n        std::string usage_json_str = delta_output->request_final_usage_json_str.value();\n        tvm::ffi::String parse_err;\n        auto usage_json = tvm::ffi::json::Parse(usage_json_str, &parse_err);\n        if (!parse_err.empty()) {\n          err_ = parse_err;\n        } else {\n          response.usage = usage_json;\n        }\n        json_response_arr.push_back(response.AsJSON());\n        request_map_.erase(request_state_it);\n        continue;\n      }\n      TVM_FFI_ICHECK_NE(delta_output->group_finish_reason.size(), 0);\n      TVM_FFI_ICHECK_EQ(delta_output->group_delta_token_ids.size(),\n                        delta_output->group_finish_reason.size());\n      TVM_FFI_ICHECK_EQ(delta_output->group_delta_token_ids.size(), rstate.streamer.size());\n\n      ChatCompletionStreamResponse response;\n      response.id = request_id;\n      response.model = rstate.model;\n      response.system_fingerprint = \"\";\n\n      for (size_t i = 0; i < delta_output->group_finish_reason.size(); ++i) {\n        // choice\n        ChatCompletionStreamResponseChoice choice;\n        Optional<String> finish_reason = delta_output->group_finish_reason[i];\n        if (finish_reason.has_value()) {\n          if (finish_reason.value() == \"stop\") {\n            choice.finish_reason = FinishReason::stop;\n          } else if (finish_reason.value() == \"length\") {\n            choice.finish_reason = FinishReason::length;\n          } else if (finish_reason.value() == \"tool_calls\") {\n            choice.finish_reason = FinishReason::tool_calls;\n          } else if (finish_reason.value() == \"error\") {\n            choice.finish_reason = FinishReason::error;\n          }\n        } else {\n          choice.finish_reason = std::nullopt;\n        }\n        choice.index = static_cast<int>(i);\n        ChatCompletionMessage delta;\n        // Size of delta_output->group_delta_token_ids Array should be 1\n        const IntTuple& 
delta_token_ids = delta_output->group_delta_token_ids[i];\n        std::vector<int32_t> delta_token_ids_vec(delta_token_ids.begin(), delta_token_ids.end());\n        std::string content = rstate.streamer[i]->Put(delta_token_ids_vec);\n        if (finish_reason.has_value()) {\n          content += rstate.streamer[i]->Finish();\n        }\n        if (!content.empty()) {\n          delta.content = content;\n        }\n        delta.role = \"assistant\";\n        choice.delta = delta;\n        if (!choice.delta.content.IsNull() || choice.finish_reason.has_value()) {\n          response.choices.push_back(choice);\n        }\n      }\n      // if it is not the usage block, choices cannot be empty\n      if (!response.choices.empty()) {\n        json_response_arr.push_back(response.AsJSON());\n      }\n    }\n    return tvm::ffi::json::Stringify(json_response_arr);\n  }\n};\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef().def(\"mlc.json_ffi.CreateJSONFFIEngine\",\n                        []() { return ffi::Module(tvm::ffi::make_object<JSONFFIEngineImpl>()); });\n}\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/json_ffi/json_ffi_engine.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file json_ffi/json_ffi_engine.h\n * \\brief The header of JSON FFI engine in MLC LLM.\n */\n#ifndef MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_\n#define MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_\n\n#include <memory>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#include \"../serve/threaded_engine.h\"\n#include \"../tokenizers/streamer.h\"\n#include \"conv_template.h\"\n#include \"openai_api_protocol.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nusing namespace tvm::runtime;\nusing namespace mlc::llm::serve;\n\n/*!\n * \\brief Engine that accepts OpenAI-style chat completion requests as JSON strings,\n * runs them on a ThreadedEngine, and streams the responses back as JSON strings\n * through a user-provided callback.\n */\nclass JSONFFIEngine {\n public:\n  JSONFFIEngine();\n\n  ~JSONFFIEngine();\n\n  /*! \\brief Add a chat completion request; if it fails, stream back an error response. */\n  bool ChatCompletion(std::string request_json_str, std::string request_id);\n\n  /*!\n   * \\brief Parse the request JSON and submit it to the engine.\n   * Returns false and records the error (see GetLastError) on failure.\n   */\n  bool AddRequest(std::string request_json_str, std::string request_id);\n\n  /*! \\brief Stream back the last error, followed by the zeroed final usage block. */\n  void StreamBackError(std::string request_id);\n\n  /*! \\brief Abort the request and drop its local state. */\n  bool Abort(std::string request_id);\n\n  /*! \\brief Return the last recorded error message. */\n  std::string GetLastError();\n\n  /*! \\brief Exit the background loops of the underlying engine. */\n  void ExitBackgroundLoop();\n\n protected:\n  /*! \\brief Local request state entry, one per request. */\n  struct RequestState {\n    /*! \\brief Model name to fill in the reply. */\n    std::string model;\n    /*! \\brief Text streamer for each choice (n per request). */\n    std::vector<TextStreamer> streamer;\n  };\n\n  // underlying threaded engine\n  std::unique_ptr<ThreadedEngine> engine_;\n  // last error message\n  std::string err_;\n  // callback that receives the streamed JSON responses\n  Function request_stream_callback_;\n  // tokenizer\n  Tokenizer tokenizer_;\n  // conversation template\n  Conversation conv_template_;\n  // generation config\n  GenerationConfig default_generation_config_;\n  // model config\n  ModelConfig model_config_;\n  // local device\n  DLDevice device_;\n  // request state map\n  std::unordered_map<String, RequestState> request_map_;\n};\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_JSON_FFI_JSON_FFI_ENGINE_H_\n"
  },
  {
    "path": "cpp/json_ffi/openai_api_protocol.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file json_ffi/openai_api_protocol.cc\n * \\brief The implementation of OpenAI API Protocol in MLC LLM.\n */\n#include \"openai_api_protocol.h\"\n\n#include \"../support/json_parser.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nResult<ChatFunction> ChatFunction::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  using TResult = Result<ChatFunction>;\n  ChatFunction chat_func;\n\n  // description\n  Result<std::optional<std::string>> description_res =\n      json::LookupOptionalWithResultReturn<std::string>(json_obj, \"description\");\n  if (description_res.IsErr()) {\n    return TResult::Error(description_res.UnwrapErr());\n  }\n  chat_func.description = description_res.Unwrap();\n\n  // name\n  Result<std::string> name_res = json::LookupWithResultReturn<std::string>(json_obj, \"name\");\n  if (name_res.IsErr()) {\n    return TResult::Error(name_res.UnwrapErr());\n  }\n  chat_func.name = name_res.Unwrap();\n\n  // parameters\n  Result<tvm::ffi::json::Object> parameters_obj_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Object>(json_obj, \"parameters\");\n  if (parameters_obj_res.IsErr()) {\n    return TResult::Error(parameters_obj_res.UnwrapErr());\n  }\n  tvm::ffi::json::Object parameters_obj = parameters_obj_res.Unwrap();\n  chat_func.parameters.reserve(parameters_obj.size());\n  for (const auto& [key, value] : parameters_obj) {\n    chat_func.parameters[key.cast<tvm::ffi::String>()] = tvm::ffi::json::Stringify(value);\n  }\n\n  return TResult::Ok(chat_func);\n}\n\ntvm::ffi::json::Object ChatFunction::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  if (this->description.has_value()) {\n    obj.Set(\"description\", this->description.value());\n  }\n  obj.Set(\"name\", this->name);\n  tvm::ffi::json::Object parameters_obj;\n  for (const auto& pair : this->parameters) {\n    parameters_obj.Set(pair.first, pair.second);\n  }\n  obj.Set(\"parameters\", 
parameters_obj);\n  return obj;\n}\n\nResult<ChatTool> ChatTool::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  using TResult = Result<ChatTool>;\n  ChatTool chatTool;\n\n  // function\n  Result<tvm::ffi::json::Object> function_obj_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Object>(json_obj, \"function\");\n  if (function_obj_res.IsErr()) {\n    return TResult::Error(function_obj_res.UnwrapErr());\n  }\n  Result<ChatFunction> function = ChatFunction::FromJSON(function_obj_res.Unwrap());\n  if (function.IsErr()) {\n    return TResult::Error(function.UnwrapErr());\n  }\n  chatTool.function = function.Unwrap();\n\n  return TResult::Ok(chatTool);\n}\n\ntvm::ffi::json::Object ChatTool::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  obj.Set(\"type\", \"function\");\n  obj.Set(\"function\", this->function.AsJSON());\n  return obj;\n}\n\nResult<ChatFunctionCall> ChatFunctionCall::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  using TResult = Result<ChatFunctionCall>;\n  ChatFunctionCall chat_func_call;\n\n  // name\n  Result<std::string> name_res = json::LookupWithResultReturn<std::string>(json_obj, \"name\");\n  if (name_res.IsErr()) {\n    return TResult::Error(name_res.UnwrapErr());\n  }\n  chat_func_call.name = name_res.Unwrap();\n\n  // arguments\n  Result<std::optional<tvm::ffi::json::Object>> arguments_obj_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Object>(json_obj, \"arguments\");\n  if (arguments_obj_res.IsErr()) {\n    return TResult::Error(arguments_obj_res.UnwrapErr());\n  }\n  std::optional<tvm::ffi::json::Object> arguments_obj = arguments_obj_res.Unwrap();\n  if (arguments_obj.has_value()) {\n    std::unordered_map<std::string, std::string> arguments;\n    arguments.reserve(arguments_obj.value().size());\n    for (const auto& [key, value] : arguments_obj.value()) {\n      arguments[key.cast<tvm::ffi::String>()] = tvm::ffi::json::Stringify(value);\n    }\n    chat_func_call.arguments = 
std::move(arguments);\n  }\n\n  return TResult::Ok(chat_func_call);\n}\n\ntvm::ffi::json::Object ChatFunctionCall::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  tvm::ffi::json::Object arguments_obj;\n  if (this->arguments.has_value()) {\n    for (const auto& pair : this->arguments.value()) {\n      arguments_obj.Set(pair.first, pair.second);\n    }\n    obj.Set(\"arguments\", arguments_obj);\n  }\n\n  obj.Set(\"name\", this->name);\n  return obj;\n}\n\nResult<ChatToolCall> ChatToolCall::FromJSON(const tvm::ffi::json::Object& json_obj) {\n  using TResult = Result<ChatToolCall>;\n  ChatToolCall chat_tool_call;\n\n  // function\n  Result<tvm::ffi::json::Object> function_obj_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Object>(json_obj, \"function\");\n  if (function_obj_res.IsErr()) {\n    return TResult::Error(function_obj_res.UnwrapErr());\n  }\n  Result<ChatFunctionCall> function_res = ChatFunctionCall::FromJSON(function_obj_res.Unwrap());\n  if (function_res.IsErr()) {\n    return TResult::Error(function_res.UnwrapErr());\n  }\n  chat_tool_call.function = function_res.Unwrap();\n\n  // id (optional): overwrites the default id when present\n  Result<std::optional<std::string>> id_res =\n      json::LookupOptionalWithResultReturn<std::string>(json_obj, \"id\");\n  if (id_res.IsErr()) {\n    return TResult::Error(id_res.UnwrapErr());\n  }\n  std::optional<std::string> id = id_res.Unwrap();\n  if (id.has_value()) {\n    chat_tool_call.id = id.value();\n  }\n\n  return TResult::Ok(chat_tool_call);\n}\n\ntvm::ffi::json::Object ChatToolCall::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  obj.Set(\"id\", this->id);\n  obj.Set(\"function\", this->function.AsJSON());\n  obj.Set(\"type\", \"function\");\n  return obj;\n}\n\nResult<ChatCompletionMessage> ChatCompletionMessage::FromJSON(\n    const tvm::ffi::json::Object& json_obj) {\n  using TResult = Result<ChatCompletionMessage>;\n  ChatCompletionMessage message;\n  ChatCompletionMessageContent content;\n\n  // content\n  if 
(json_obj.count(\"content\") == 0) {\n    return TResult::Error(\"ValueError: key \\\"content\\\" not found in the chat completion.\");\n  }\n  tvm::ffi::json::Value content_val = json_obj.at(\"content\");\n  if (content_val.try_cast<std::string>().has_value()) {\n    content = content_val.cast<std::string>();\n  } else if (content_val == nullptr) {\n    // skip\n  } else {\n    // most complicated case\n    std::vector<std::unordered_map<std::string, std::string>> parts;\n    Result<tvm::ffi::json::Array> content_arr_res =\n        json::LookupWithResultReturn<tvm::ffi::json::Array>(json_obj, \"content\");\n    if (content_arr_res.IsErr()) {\n      return TResult::Error(content_arr_res.UnwrapErr());\n    }\n    tvm::ffi::json::Array content_arr = content_arr_res.Unwrap();\n    for (const auto& item : content_arr) {\n      if (!item.try_cast<tvm::ffi::json::Object>().has_value()) {\n        return TResult::Error(\"The content of chat completion message is not an object\");\n      }\n      tvm::ffi::json::Object item_obj = item.cast<tvm::ffi::json::Object>();\n      std::unordered_map<std::string, std::string> item_map;\n      for (const auto& [key, value] : item_obj) {\n        item_map[key.cast<tvm::ffi::String>()] = tvm::ffi::json::Stringify(value);\n      }\n      parts.push_back(std::move(item_map));\n    }\n    content = parts;\n  }\n  message.content = content;\n\n  // role\n  Result<std::string> role_str_res = json::LookupWithResultReturn<std::string>(json_obj, \"role\");\n  if (role_str_res.IsErr()) {\n    return TResult::Error(role_str_res.UnwrapErr());\n  }\n  std::string role_str = role_str_res.Unwrap();\n  if (role_str == \"system\" || role_str == \"user\" || role_str == \"assistant\" || role_str == \"tool\") {\n    message.role = role_str;\n  } else {\n    return TResult::Error(\"Invalid role in chat completion message: \" + role_str);\n  }\n\n  // name\n  Result<std::optional<std::string>> name_res =\n      
json::LookupOptionalWithResultReturn<std::string>(json_obj, \"name\");\n  if (name_res.IsErr()) {\n    return TResult::Error(name_res.UnwrapErr());\n  }\n  message.name = name_res.Unwrap();\n\n  // tool calls\n  Result<std::optional<tvm::ffi::json::Array>> tool_calls_arr_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Array>(json_obj, \"tool_calls\");\n  if (tool_calls_arr_res.IsErr()) {\n    return TResult::Error(tool_calls_arr_res.UnwrapErr());\n  }\n  std::optional<tvm::ffi::json::Array> tool_calls_arr = tool_calls_arr_res.Unwrap();\n  if (tool_calls_arr.has_value()) {\n    std::vector<ChatToolCall> tool_calls;\n    tool_calls.reserve(tool_calls_arr.value().size());\n    for (const auto& item : tool_calls_arr.value()) {\n      if (!item.try_cast<tvm::ffi::json::Object>().has_value()) {\n        return TResult::Error(\"A tool call item in the chat completion message is not an object\");\n      }\n      Result<ChatToolCall> tool_call = ChatToolCall::FromJSON(item.cast<tvm::ffi::json::Object>());\n      if (tool_call.IsErr()) {\n        return TResult::Error(tool_call.UnwrapErr());\n      }\n      tool_calls.push_back(tool_call.Unwrap());\n    }\n    message.tool_calls = tool_calls;\n  }\n\n  // tool call id\n  Result<std::optional<std::string>> tool_call_id_res =\n      json::LookupOptionalWithResultReturn<std::string>(json_obj, \"tool_call_id\");\n  if (tool_call_id_res.IsErr()) {\n    return TResult::Error(tool_call_id_res.UnwrapErr());\n  }\n  message.tool_call_id = tool_call_id_res.Unwrap();\n\n  return TResult::Ok(message);\n}\n\nResult<ChatCompletionRequest> ChatCompletionRequest::FromJSON(const std::string& json_str) {\n  using TResult = Result<ChatCompletionRequest>;\n  Result<tvm::ffi::json::Object> json_obj_res = json::ParseToJSONObjectWithResultReturn(json_str);\n  if (json_obj_res.IsErr()) {\n    return TResult::Error(json_obj_res.UnwrapErr());\n  }\n  tvm::ffi::json::Object json_obj = json_obj_res.Unwrap();\n  ChatCompletionRequest 
request;\n\n  // messages\n  Result<tvm::ffi::json::Array> messages_arr_res =\n      json::LookupWithResultReturn<tvm::ffi::json::Array>(json_obj, \"messages\");\n  if (messages_arr_res.IsErr()) {\n    return TResult::Error(messages_arr_res.UnwrapErr());\n  }\n  std::vector<ChatCompletionMessage> messages;\n  tvm::ffi::json::Array messages_arr = messages_arr_res.Unwrap();\n  for (const auto& item : messages_arr) {\n    if (!item.try_cast<tvm::ffi::json::Object>().has_value()) {\n      return TResult::Error(\"A message in chat completion request is not an object\");\n    }\n    tvm::ffi::json::Object item_obj = item.cast<tvm::ffi::json::Object>();\n    Result<ChatCompletionMessage> message = ChatCompletionMessage::FromJSON(item_obj);\n    if (message.IsErr()) {\n      return TResult::Error(message.UnwrapErr());\n    }\n    messages.push_back(message.Unwrap());\n  }\n  request.messages = messages;\n\n  // model\n  Result<std::optional<std::string>> model_res =\n      json::LookupOptionalWithResultReturn<std::string>(json_obj, \"model\");\n  if (model_res.IsErr()) {\n    return TResult::Error(model_res.UnwrapErr());\n  }\n  request.model = model_res.Unwrap();\n\n  // temperature\n  Result<std::optional<double>> temperature_res =\n      json::LookupOptionalWithResultReturn<double>(json_obj, \"temperature\");\n  if (temperature_res.IsErr()) {\n    return TResult::Error(temperature_res.UnwrapErr());\n  }\n  request.temperature = temperature_res.Unwrap();\n  // top_p\n  Result<std::optional<double>> top_p_res =\n      json::LookupOptionalWithResultReturn<double>(json_obj, \"top_p\");\n  if (top_p_res.IsErr()) {\n    return TResult::Error(top_p_res.UnwrapErr());\n  }\n  request.top_p = top_p_res.Unwrap();\n  // max_tokens\n  Result<std::optional<int64_t>> max_tokens_res =\n      json::LookupOptionalWithResultReturn<int64_t>(json_obj, \"max_tokens\");\n  if (max_tokens_res.IsErr()) {\n    return TResult::Error(max_tokens_res.UnwrapErr());\n  }\n  request.max_tokens = 
max_tokens_res.Unwrap();\n  // n\n  Result<int64_t> n_res = json::LookupOrDefaultWithResultReturn<int64_t>(json_obj, \"n\", 1);\n  if (n_res.IsErr()) {\n    return TResult::Error(n_res.UnwrapErr());\n  }\n  request.n = n_res.Unwrap();\n  // frequency_penalty\n  Result<std::optional<double>> frequency_penalty_res =\n      json::LookupOptionalWithResultReturn<double>(json_obj, \"frequency_penalty\");\n  if (frequency_penalty_res.IsErr()) {\n    return TResult::Error(frequency_penalty_res.UnwrapErr());\n  }\n  request.frequency_penalty = frequency_penalty_res.Unwrap();\n  // presence_penalty\n  Result<std::optional<double>> presence_penalty_res =\n      json::LookupOptionalWithResultReturn<double>(json_obj, \"presence_penalty\");\n  if (presence_penalty_res.IsErr()) {\n    return TResult::Error(presence_penalty_res.UnwrapErr());\n  }\n  request.presence_penalty = presence_penalty_res.Unwrap();\n  // seed\n  Result<std::optional<int64_t>> seed_res =\n      json::LookupOptionalWithResultReturn<int64_t>(json_obj, \"seed\");\n  if (seed_res.IsErr()) {\n    return TResult::Error(seed_res.UnwrapErr());\n  }\n  request.seed = seed_res.Unwrap();\n\n  // stop strings\n  Result<std::optional<tvm::ffi::json::Array>> stop_strs_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Array>(json_obj, \"stop\");\n  if (stop_strs_res.IsErr()) {\n    return TResult::Error(stop_strs_res.UnwrapErr());\n  }\n  std::optional<tvm::ffi::json::Array> stop_strs = stop_strs_res.Unwrap();\n  if (stop_strs.has_value()) {\n    std::vector<std::string> stop;\n    for (const auto& stop_str_value : stop_strs.value()) {\n      if (!stop_str_value.try_cast<std::string>().has_value()) {\n        return TResult::Error(\"One given value in field \\\"stop\\\" is not a string.\");\n      }\n      stop.push_back(stop_str_value.cast<std::string>());\n    }\n    request.stop = std::move(stop);\n  }\n\n  // tool_choice\n  Result<std::string> tool_choice_res =\n      
json::LookupOrDefaultWithResultReturn<std::string>(json_obj, \"tool_choice\", \"auto\");\n  if (tool_choice_res.IsErr()) {\n    return TResult::Error(tool_choice_res.UnwrapErr());\n  }\n  request.tool_choice = tool_choice_res.Unwrap();\n\n  // tools\n  Result<std::optional<tvm::ffi::json::Array>> tools_arr_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Array>(json_obj, \"tools\");\n  if (tools_arr_res.IsErr()) {\n    return TResult::Error(tools_arr_res.UnwrapErr());\n  }\n  std::optional<tvm::ffi::json::Array> tools_arr = tools_arr_res.Unwrap();\n  if (tools_arr.has_value()) {\n    std::vector<ChatTool> tools;\n    tools.reserve(tools_arr.value().size());\n    for (const auto& item : tools_arr.value()) {\n      if (!item.try_cast<tvm::ffi::json::Object>().has_value()) {\n        return TResult::Error(\"A tool of the chat completion request is not an object\");\n      }\n      Result<ChatTool> tool = ChatTool::FromJSON(item.cast<tvm::ffi::json::Object>());\n      if (tool.IsErr()) {\n        return TResult::Error(tool.UnwrapErr());\n      }\n      tools.push_back(tool.Unwrap());\n    }\n    request.tools = tools;\n  }\n\n  // response format\n  std::optional<tvm::ffi::json::Object> response_format_obj =\n      json::LookupOptional<tvm::ffi::json::Object>(json_obj, \"response_format\");\n  if (response_format_obj.has_value()) {\n    Result<ResponseFormat> response_format_res =\n        ResponseFormat::FromJSON(response_format_obj.value());\n    if (response_format_res.IsErr()) {\n      return TResult::Error(response_format_res.UnwrapErr());\n    }\n    request.response_format = response_format_res.Unwrap();\n  }\n\n  // debug_config\n  Result<std::optional<tvm::ffi::json::Object>> debug_config_opt_res =\n      json::LookupOptionalWithResultReturn<tvm::ffi::json::Object>(json_obj, \"debug_config\");\n  if (debug_config_opt_res.IsErr()) {\n    return TResult::Error(debug_config_opt_res.UnwrapErr());\n  }\n  auto debug_config_opt = 
debug_config_opt_res.Unwrap();\n  if (debug_config_opt.has_value()) {\n    Result<DebugConfig> debug_config_res = DebugConfig::FromJSON(debug_config_opt.value());\n    if (debug_config_res.IsErr()) {\n      return TResult::Error(debug_config_res.UnwrapErr());\n    }\n    request.debug_config = debug_config_res.Unwrap();\n  }\n\n  // TODO: Other parameters\n  return TResult::Ok(request);\n}\n\ntvm::ffi::json::Object ChatCompletionMessage::AsJSON() const {\n  tvm::ffi::json::Object obj;\n\n  if (this->content.IsText()) {\n    obj.Set(\"content\", this->content.Text());\n  } else if (this->content.IsParts()) {\n    tvm::ffi::json::Array content_arr;\n    for (const auto& item : this->content.Parts()) {\n      tvm::ffi::json::Object item_obj;\n      for (const auto& pair : item) {\n        item_obj.Set(pair.first, pair.second);\n      }\n      content_arr.push_back(item_obj);\n    }\n    obj.Set(\"content\", content_arr);\n  }\n\n  obj.Set(\"role\", this->role);\n\n  if (this->name.has_value()) {\n    obj.Set(\"name\", this->name.value());\n  }\n  if (this->tool_call_id.has_value()) {\n    obj.Set(\"tool_call_id\", this->tool_call_id.value());\n  }\n  if (this->tool_calls.has_value()) {\n    tvm::ffi::json::Array tool_calls_arr;\n    for (const auto& tool_call : this->tool_calls.value()) {\n      tool_calls_arr.push_back(tool_call.AsJSON());\n    }\n    obj.Set(\"tool_calls\", tool_calls_arr);\n  }\n  return obj;\n}\n\ntvm::ffi::json::Object ChatCompletionResponseChoice::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  if (!this->finish_reason.has_value()) {\n    obj.Set(\"finish_reason\", nullptr);\n  } else {\n    if (this->finish_reason == FinishReason::stop) {\n      obj.Set(\"finish_reason\", \"stop\");\n    } else if (this->finish_reason == FinishReason::length) {\n      obj.Set(\"finish_reason\", \"length\");\n    } else if (this->finish_reason == FinishReason::tool_calls) {\n      obj.Set(\"finish_reason\", \"tool_calls\");\n    } else if (this->finish_reason 
== FinishReason::error) {\n      obj.Set(\"finish_reason\", \"error\");\n    }\n  }\n  obj.Set(\"index\", static_cast<int64_t>(this->index));\n  obj.Set(\"message\", this->message.AsJSON());\n  return obj;\n}\n\ntvm::ffi::json::Object ChatCompletionStreamResponseChoice::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  if (!this->finish_reason.has_value()) {\n    obj.Set(\"finish_reason\", nullptr);\n  } else {\n    if (this->finish_reason.value() == FinishReason::stop) {\n      obj.Set(\"finish_reason\", \"stop\");\n    } else if (this->finish_reason.value() == FinishReason::length) {\n      obj.Set(\"finish_reason\", \"length\");\n    } else if (this->finish_reason.value() == FinishReason::tool_calls) {\n      obj.Set(\"finish_reason\", \"tool_calls\");\n    } else if (this->finish_reason.value() == FinishReason::error) {\n      obj.Set(\"finish_reason\", \"error\");\n    }\n  }\n\n  obj.Set(\"index\", static_cast<int64_t>(this->index));\n  obj.Set(\"delta\", this->delta.AsJSON());\n  return obj;\n}\n\ntvm::ffi::json::Object ChatCompletionResponse::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  obj.Set(\"id\", this->id);\n  tvm::ffi::json::Array choices_arr;\n  for (const auto& choice : this->choices) {\n    choices_arr.push_back(choice.AsJSON());\n  }\n  obj.Set(\"choices\", choices_arr);\n  obj.Set(\"created\", static_cast<int64_t>(this->created));\n  obj.Set(\"model\", this->model);\n  obj.Set(\"system_fingerprint\", this->system_fingerprint);\n  obj.Set(\"object\", this->object);\n  return obj;\n}\n\ntvm::ffi::json::Object ChatCompletionStreamResponse::AsJSON() const {\n  tvm::ffi::json::Object obj;\n  obj.Set(\"id\", this->id);\n\n  tvm::ffi::json::Array choices_arr;\n  for (const auto& choice : this->choices) {\n    choices_arr.push_back(choice.AsJSON());\n  }\n  obj.Set(\"choices\", choices_arr);\n\n  obj.Set(\"created\", static_cast<int64_t>(this->created));\n  obj.Set(\"model\", this->model);\n  obj.Set(\"system_fingerprint\", 
this->system_fingerprint);\n  obj.Set(\"object\", this->object);\n  if (usage.has_value()) {\n    obj.Set(\"usage\", usage.value());\n  }\n  return obj;\n}\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/json_ffi/openai_api_protocol.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file json_ffi/openai_api_protocol.h\n * \\brief The header of OpenAI API Protocol in MLC LLM.\n */\n#ifndef MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H\n#define MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H\n\n#include <tvm/ffi/extra/json.h>\n\n#include <algorithm>\n#include <cstdlib>\n#include <ctime>\n#include <optional>\n#include <random>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#include \"../serve/config.h\"\n#include \"../support/result.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nusing serve::DebugConfig;\nusing serve::ResponseFormat;\n\nenum class Type { text, json_object, function };\nenum class FinishReason { stop, length, tool_calls, error };\n\ninline std::string GenerateUUID(size_t length) {\n  auto randchar = []() -> char {\n    const char charset[] =\n        \"0123456789\"\n        \"ABCDEFGHIJKLMNOPQRSTUVWXYZ\"\n        \"abcdefghijklmnopqrstuvwxyz\";\n    const size_t max_index = (sizeof(charset) - 1);\n    return charset[rand() % max_index];\n  };\n  std::string str(length, 0);\n  std::generate_n(str.begin(), length, randchar);\n  return str;\n}\n\nclass ChatFunction {\n public:\n  std::optional<std::string> description = std::nullopt;\n  std::string name;\n  // Todo: change to std::vector<std::pair<std::string, std::string>>?\n  std::unordered_map<std::string, std::string>\n      parameters;  // Assuming parameters are string key-value pairs\n\n  static Result<ChatFunction> FromJSON(const tvm::ffi::json::Object& json);\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatTool {\n public:\n  Type type = Type::function;\n  ChatFunction function;\n\n  static Result<ChatTool> FromJSON(const tvm::ffi::json::Object& json);\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatFunctionCall {\n public:\n  std::string name;\n  std::optional<std::unordered_map<std::string, std::string>> arguments =\n      std::nullopt;  // Assuming arguments are string key-value pairs\n\n  static 
Result<ChatFunctionCall> FromJSON(const tvm::ffi::json::Object& json);\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatToolCall {\n public:\n  std::string id = \"call_\" + GenerateUUID(8);\n  Type type = Type::function;\n  ChatFunctionCall function;\n\n  static Result<ChatToolCall> FromJSON(const tvm::ffi::json::Object& json);\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatCompletionMessageContent {\n public:\n  ChatCompletionMessageContent() = default;\n\n  ChatCompletionMessageContent(std::nullopt_t) {}  // NOLINT(*)\n\n  ChatCompletionMessageContent(std::string text) : text_(text) {}  // NOLINT(*)\n\n  ChatCompletionMessageContent(\n      std::vector<std::unordered_map<std::string, std::string>> parts)  // NOLINT(*)\n      : parts_(parts) {}\n\n  bool IsNull() const { return !IsText() && !IsParts(); }\n\n  bool IsText() const { return text_.operator bool(); }\n\n  bool IsParts() const { return parts_.operator bool(); }\n\n  const std::string& Text() const { return text_.value(); }\n\n  const std::vector<std::unordered_map<std::string, std::string>>& Parts() const {\n    return parts_.value();\n  }\n\n private:\n  /*! 
\\brief used to store text content */\n  std::optional<std::string> text_;\n  /*! \\brief used to store content parts (string key-value maps) */\n  std::optional<std::vector<std::unordered_map<std::string, std::string>>> parts_;\n};\n\nclass ChatCompletionMessage {\n public:\n  ChatCompletionMessageContent content =\n      std::nullopt;  // Null, plain text, or a list of string key-value parts\n  std::string role;\n  std::optional<std::string> name = std::nullopt;\n  std::optional<std::vector<ChatToolCall>> tool_calls = std::nullopt;\n  std::optional<std::string> tool_call_id = std::nullopt;\n\n  static Result<ChatCompletionMessage> FromJSON(const tvm::ffi::json::Object& json);\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatCompletionRequest {\n public:\n  std::vector<ChatCompletionMessage> messages;\n  std::optional<std::string> model = std::nullopt;\n  std::optional<double> frequency_penalty = std::nullopt;\n  std::optional<double> presence_penalty = std::nullopt;\n  bool logprobs = false;\n  int top_logprobs = 0;\n  std::optional<std::vector<std::pair<int, float>>> logit_bias = std::nullopt;\n  std::optional<int> max_tokens = std::nullopt;\n  int n = 1;\n  std::optional<int> seed = std::nullopt;\n  std::optional<std::vector<std::string>> stop = std::nullopt;\n  bool stream = false;\n  std::optional<double> temperature = std::nullopt;\n  std::optional<double> top_p = std::nullopt;\n  std::optional<std::vector<ChatTool>> tools = std::nullopt;\n  std::optional<std::string> tool_choice = std::nullopt;\n  std::optional<std::string> user = std::nullopt;\n  bool ignore_eos = false;\n  std::optional<ResponseFormat> response_format = std::nullopt;\n  std::optional<DebugConfig> debug_config = std::nullopt;\n\n  /*! \\brief Parse and create a ChatCompletionRequest instance from the given JSON string. 
*/\n  static Result<ChatCompletionRequest> FromJSON(const std::string& json_str);\n\n  // TODO: check_penalty_range, check_logit_bias, check_logprobs\n};\n\nclass ChatCompletionResponseChoice {\n public:\n  std::optional<FinishReason> finish_reason;\n  int index = 0;\n  ChatCompletionMessage message;\n  // TODO: logprobs\n\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatCompletionStreamResponseChoice {\n public:\n  std::optional<FinishReason> finish_reason;\n  int index = 0;\n  ChatCompletionMessage delta;\n  // TODO: logprobs\n\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatCompletionResponse {\n public:\n  std::string id;\n  std::vector<ChatCompletionResponseChoice> choices;\n  int created = static_cast<int>(std::time(nullptr));\n  std::string model;\n  std::string system_fingerprint;\n  std::string object = \"chat.completion\";\n  // TODO: usage_info\n\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nclass ChatCompletionStreamResponse {\n public:\n  std::string id;\n  std::vector<ChatCompletionStreamResponseChoice> choices;\n  int created = static_cast<int>(std::time(nullptr));\n  std::string model;\n  std::string system_fingerprint;\n  std::string object = \"chat.completion.chunk\";\n  std::optional<tvm::ffi::json::Value> usage;\n\n  tvm::ffi::json::Object AsJSON() const;\n};\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_JSON_FFI_OPENAI_API_PROTOCOL_H\n"
  },
  {
    "path": "cpp/metadata/model.cc",
    "content": "#include \"./model.h\"\n\n#include <unordered_map>\n\n#include \"../support/json_parser.h\"\n\nnamespace mlc {\nnamespace llm {\n\nusing namespace tvm::runtime;\nusing tvm::ffi::Function;\nusing tvm::ffi::Optional;\n\nModelMetadata::Param::Preproc ModelMetadata::Param::Preproc::FromJSON(\n    const tvm::ffi::json::Object& js, const tvm::ffi::json::Object& model_config) {\n  Preproc preproc;\n  TVM_FFI_ICHECK_GE(js.size(), 3) << \"ValueError: Invalid preprocessing info in JSON\";\n  preproc.func_name = json::Lookup<std::string>(js, \"func_name\");\n  json::SymShapeTuple sym_out_shape = json::Lookup<json::SymShapeTuple>(js, \"out_shape\");\n  preproc.out_shape = sym_out_shape.ToStatic(model_config);\n  json::SymShapeTuple sym_in_shape =\n      json::LookupOrDefault<json::SymShapeTuple>(js, \"in_shape\", sym_out_shape);\n  preproc.in_shape = sym_in_shape.ToStatic(model_config);\n  preproc.out_dtype = json::Lookup<DataType>(js, \"out_dtype\");\n  return preproc;\n}\n\nModelMetadata::Param ModelMetadata::Param::FromJSON(const tvm::ffi::json::Object& param,\n                                                    const tvm::ffi::json::Object& model_config) {\n  Param result;\n  result.name = json::Lookup<std::string>(param, \"name\");\n  result.dtype = json::Lookup<DataType>(param, \"dtype\");\n  // A shape being `-1` means that it is dynamic\n  json::SymShapeTuple sym_shape = json::Lookup<json::SymShapeTuple>(param, \"shape\");\n  result.shape = sym_shape.ToStatic(model_config);\n  // - \"preproc\"\n  tvm::ffi::json::Array preprocs = json::Lookup<tvm::ffi::json::Array>(param, \"preprocs\");\n  result.preprocs.reserve(preprocs.size());\n  for (int i = 0; i < preprocs.size(); i++) {\n    result.preprocs.emplace_back(ModelMetadata::Param::Preproc::FromJSON(\n        json::Lookup<tvm::ffi::json::Object>(preprocs, i), model_config));\n  }\n  // - \"pipeline_stages\"\n  int pipeline_parallel_stages =\n      json::LookupOrDefault<int64_t>(model_config, 
\"pipeline_parallel_stages\", 1);\n  std::optional<tvm::ffi::json::Array> opt_pipeline_stages =\n      json::LookupOptional<tvm::ffi::json::Array>(param, \"pipeline_stages\");\n  if (pipeline_parallel_stages > 1) {\n    TVM_FFI_ICHECK(opt_pipeline_stages.has_value())\n        << \"The pipeline stage is undefined for parameter \\\"\" << result.name\n        << \"\\\" when the number of pipeline parallel stages is \" << pipeline_parallel_stages;\n  }\n  if (opt_pipeline_stages.has_value()) {\n    result.pipeline_stages.reserve(opt_pipeline_stages.value().size());\n    for (const tvm::ffi::json::Value& v : opt_pipeline_stages.value()) {\n      auto int_opt = v.try_cast<int64_t>();\n      TVM_FFI_ICHECK(int_opt.has_value()) << \"Pipeline stage is not an integer.\";\n      result.pipeline_stages.push_back(*int_opt);\n    }\n  } else {\n    result.pipeline_stages = {0};\n  }\n  return result;\n}\n\nModelMetadata::KVCacheMetadata ModelMetadata::KVCacheMetadata::FromJSON(\n    const tvm::ffi::json::Object& json) {\n  KVCacheMetadata kv_cache_metadata;\n  kv_cache_metadata.num_hidden_layers = json::Lookup<int64_t>(json, \"num_hidden_layers\");\n  kv_cache_metadata.head_dim = json::Lookup<int64_t>(json, \"head_dim\");\n  kv_cache_metadata.num_attention_heads = json::Lookup<int64_t>(json, \"num_attention_heads\");\n  kv_cache_metadata.num_key_value_heads = json::Lookup<int64_t>(json, \"num_key_value_heads\");\n  return kv_cache_metadata;\n}\n\nModelMetadata ModelMetadata::FromJSON(const tvm::ffi::json::Object& metadata,\n                                      const tvm::ffi::json::Object& model_config) {\n  ModelMetadata result;\n  result.model_type = json::Lookup<std::string>(metadata, \"model_type\");\n  result.quantization = json::Lookup<std::string>(metadata, \"quantization\");\n  result.context_window_size = json::Lookup<int64_t>(metadata, \"context_window_size\");\n  result.prefill_chunk_size = json::Lookup<int64_t>(metadata, \"prefill_chunk_size\");\n  
result.max_batch_size = json::Lookup<int64_t>(metadata, \"max_batch_size\");\n  if (metadata.count(\"sliding_window_size\"))\n    result.sliding_window_size = json::Lookup<int64_t>(metadata, \"sliding_window_size\");\n  if (metadata.count(\"sliding_window\"))  // to be removed after SLM migration\n    result.sliding_window_size = json::Lookup<int64_t>(metadata, \"sliding_window\");\n  if (metadata.count(\"attention_sink_size\"))  // remove after sink is decoupled from model lib\n    result.attention_sink_size = json::Lookup<int64_t>(metadata, \"attention_sink_size\");\n  result.seqlen_padding_factor =\n      json::LookupOrDefault<int64_t>(metadata, \"seqlen_padding_factor\", 1);\n  result.tensor_parallel_shards = json::Lookup<int64_t>(metadata, \"tensor_parallel_shards\");\n  result.pipeline_parallel_stages =\n      json::LookupOrDefault<int64_t>(metadata, \"pipeline_parallel_stages\", 1);\n  result.disaggregation = json::LookupOrDefault<bool>(metadata, \"disaggregation\", false);\n  result.model_task = json::LookupOrDefault<std::string>(metadata, \"model_task\", \"chat\");\n  if (metadata.count(\"embedding_metadata\")) {\n    tvm::ffi::json::Object emb =\n        json::Lookup<tvm::ffi::json::Object>(metadata, \"embedding_metadata\");\n    result.embedding_model_type = json::LookupOrDefault<std::string>(emb, \"model_type\", \"\");\n    result.embedding_pooling_strategy =\n        json::LookupOrDefault<std::string>(emb, \"pooling_strategy\", \"\");\n    result.embedding_normalize = json::LookupOrDefault<bool>(emb, \"normalize\", false);\n  }\n  result.kv_state_kind = KVStateKindFromString(\n      json::LookupOrDefault<std::string>(metadata, \"kv_state_kind\", \"kv_cache\"));\n  if (result.kv_state_kind != KVStateKind::kNone &&\n      result.kv_state_kind != KVStateKind::kRNNState) {\n    result.kv_cache_metadata =\n        KVCacheMetadata::FromJSON(json::Lookup<tvm::ffi::json::Object>(metadata, \"kv_cache\"));\n  } else {\n    result.kv_cache_metadata = 
{/*num_hidden_layers=*/0,\n                                /*num_attention_heads=*/0,\n                                /*num_key_value_heads=*/0,\n                                /*head_dim=*/0};\n  }\n  {\n    std::vector<ModelMetadata::Param>& params = result.params;\n    tvm::ffi::json::Array json_params = json::Lookup<tvm::ffi::json::Array>(metadata, \"params\");\n    params.reserve(json_params.size());\n    for (int i = 0, n = json_params.size(); i < n; ++i) {\n      params.emplace_back(ModelMetadata::Param::FromJSON(\n          json::Lookup<tvm::ffi::json::Object>(json_params, i), model_config));\n    }\n  }\n  {\n    std::unordered_map<std::string, int64_t>& memory_usage = result.memory_usage;\n    tvm::ffi::json::Object json_memory_usage =\n        json::Lookup<tvm::ffi::json::Object>(metadata, \"memory_usage\");\n    memory_usage.reserve(json_memory_usage.size());\n    for (const auto& [key, val] : json_memory_usage) {\n      std::string func_name = key.cast<tvm::ffi::String>();\n      memory_usage[func_name] = json::Lookup<int64_t>(json_memory_usage, func_name);\n    }\n  }\n  return result;\n}\n\nModelMetadata ModelMetadata::FromModule(Module module, const tvm::ffi::json::Object& model_config) {\n  std::string json_str = \"\";\n  Optional<Function> pf = module->GetFunction(\"_metadata\");\n  TVM_FFI_ICHECK(pf.defined()) << \"ValueError: _metadata function not found in module\";\n  json_str = pf.value()().cast<String>();\n  tvm::ffi::json::Object json = json::ParseToJSONObject(json_str);\n  try {\n    return ModelMetadata::FromJSON(json, model_config);\n  } catch (const std::exception& e) {\n    LOG(WARNING) << \"Failed to parse metadata:\\n\" << json_str << \"\\nerror: \" << e.what();\n    // Rethrow the original exception object; `throw e;` would slice it to std::exception.\n    throw;\n  }\n}\n\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/metadata/model.h",
    "content": "/*!\n * \\file model.h\n * \\brief Metadata stored in model lib\n */\n#ifndef MLC_LLM_CPP_MODEL_METADATA_H_\n#define MLC_LLM_CPP_MODEL_METADATA_H_\n\n#include <tvm/ffi/container/shape.h>\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/extra/module.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/data_type.h>\n#include <tvm/runtime/module.h>\n\n#include <unordered_map>\n\nnamespace mlc {\nnamespace llm {\n\nusing tvm::ffi::Module;\nusing tvm::ffi::Shape;\nusing tvm::ffi::String;\nusing tvm::runtime::DataType;\n\n/*! \\brief The kind of cache. */\nenum class KVStateKind : int {\n  kKVCache = 0,\n  kRNNState = 1,\n  kNone = 2,\n};\n\ninline std::string KVStateKindToString(KVStateKind kv_state_kind) {\n  if (kv_state_kind == KVStateKind::kKVCache) {\n    return \"kv_cache\";\n  } else if (kv_state_kind == KVStateKind::kRNNState) {\n    return \"rnn_state\";\n  } else if (kv_state_kind == KVStateKind::kNone) {\n    return \"none\";\n  } else {\n    LOG(FATAL) << \"Invalid kv state kind: \" << static_cast<int>(kv_state_kind);\n  }\n}\n\ninline KVStateKind KVStateKindFromString(const std::string& kv_state_kind) {\n  if (kv_state_kind == \"kv_cache\") {\n    return KVStateKind::kKVCache;\n  } else if (kv_state_kind == \"rnn_state\") {\n    return KVStateKind::kRNNState;\n  } else if (kv_state_kind == \"none\") {\n    return KVStateKind::kNone;\n  } else {\n    LOG(FATAL) << \"Invalid kv state kind string: \" << kv_state_kind;\n  }\n}\nstruct ModelMetadata {\n  struct Param {\n    struct Preproc {\n      String func_name;\n      Shape in_shape;\n      Shape out_shape;\n      DataType out_dtype;\n      static Preproc FromJSON(const tvm::ffi::json::Object& js,\n                              const tvm::ffi::json::Object& model_config);\n    };\n\n    String name;\n    Shape shape;\n    DataType dtype;\n    std::vector<Preproc> preprocs;\n    std::vector<int> pipeline_stages;\n    static Param FromJSON(const tvm::ffi::json::Object& param_obj,\n         
                 const tvm::ffi::json::Object& model_config);\n  };\n\n  struct KVCacheMetadata {\n    int64_t num_hidden_layers;\n    int64_t num_attention_heads;\n    int64_t num_key_value_heads;\n    int64_t head_dim;\n    static KVCacheMetadata FromJSON(const tvm::ffi::json::Object& json);\n  };\n\n  std::string model_type;\n  std::string quantization;\n  int64_t context_window_size;\n  int64_t prefill_chunk_size;\n  int64_t max_batch_size;\n  int64_t sliding_window_size;\n  int64_t tensor_parallel_shards;\n  int64_t pipeline_parallel_stages;\n  bool disaggregation;\n  int64_t attention_sink_size;\n  int64_t seqlen_padding_factor;\n  std::vector<Param> params;\n  std::unordered_map<std::string, int64_t> memory_usage;\n  KVStateKind kv_state_kind;\n  KVCacheMetadata kv_cache_metadata;\n  std::string model_task;\n  std::string embedding_model_type;\n  std::string embedding_pooling_strategy;\n  bool embedding_normalize = false;\n\n  static ModelMetadata FromJSON(const tvm::ffi::json::Object& json_str,\n                                const tvm::ffi::json::Object& model_config);\n  static ModelMetadata FromModule(Module module, const tvm::ffi::json::Object& model_config);\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_CPP_MODEL_METADATA_H_\n"
  },
  {
    "path": "cpp/multi_gpu/builtin.cc",
    "content": "/*!\n * \\file builtin.cc\n * \\brief Multi-GPU builtin functions in MLC LLM.\n */\n#ifndef MLC_SINGLE_GPU_ONLY\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/container/shape.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/optional.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/node/cast.h>\n#include <tvm/runtime/disco/builtin.h>\n#include <tvm/runtime/disco/disco_worker.h>\n#include <tvm/runtime/tensor.h>\n#include <tvm/runtime/vm/vm.h>\n\nnamespace mlc {\nnamespace llm {\nnamespace multi_gpu {\n\nusing namespace tvm::runtime;\nusing tvm::Downcast;\nusing tvm::ffi::Array;\nusing tvm::ffi::Optional;\nusing tvm::ffi::Shape;\n\nObjectRef DispatchFunctionByGroup(tvm::ffi::AnyView vm_arg,\n                                  Array<Array<ObjectRef>> funcs_and_args) {\n  using namespace vm;\n  VirtualMachine* vm = VirtualMachine::GetContextPtr(vm_arg);\n  DiscoWorker* worker = DiscoWorker::ThreadLocal();\n  int world_size = worker->num_workers;\n  int group_size = worker->num_workers / worker->num_groups;\n  int num_group = world_size / group_size;\n  TVM_FFI_ICHECK_EQ(funcs_and_args.size(), num_group)\n      << \"Number of groups mismatches. 
There are \" << num_group\n      << \" groups while the function/arg array has \" << funcs_and_args.size() << \" elements.\";\n\n  int group_id = worker->worker_id / group_size;\n  TVM_FFI_ICHECK(!funcs_and_args[group_id].empty())\n      << \"No function is provided for group \" << group_id;\n  VMClosure func = Downcast<VMClosure>(funcs_and_args[group_id][0]);\n\n  int num_args = static_cast<int>(funcs_and_args[group_id].size()) - 1;\n  std::vector<tvm::ffi::AnyView> packed_args(num_args);\n  for (int i = 0; i < num_args; ++i) {\n    // NOTE: The argument needs to be explicitly defined so that it does not\n    // have type code kTVMObjectRValueRefArg.\n    packed_args[i] = funcs_and_args[group_id][1 + i];\n  }\n\n  tvm::ffi::Any rv;\n  vm->InvokeClosurePacked(func,\n                          tvm::ffi::PackedArgs(packed_args.data(), packed_args.size()), &rv);\n  return rv.cast<ObjectRef>();\n}\n\nObjectRef SendFromLastGroupToWorker0(Tensor send, Optional<Tensor> recv, Shape shape,\n                                     DataType dtype) {\n  DiscoWorker* worker = DiscoWorker::ThreadLocal();\n  int worker_id = worker->worker_id;\n  int world_size = worker->num_workers;\n  int group_size = worker->num_workers / worker->num_groups;\n  TVM_FFI_ICHECK_NE(world_size, group_size)\n      << \"Cannot send from the last group to worker 0 when there is only one group.\";\n  int sender_id = world_size - group_size;\n  if (worker_id == 0) {\n    TVM_FFI_ICHECK(recv.defined()) << \"The receive Tensor is undefined for worker 0.\";\n    Tensor recv_arr = recv.value().CreateView(shape, dtype);\n    RecvFromWorker(recv_arr, sender_id);\n    return recv_arr;\n  } else if (worker_id == sender_id) {\n    TVM_FFI_ICHECK_EQ(DataType(send->dtype), dtype)\n        << \"The dtype of the src Tensor does not match the expected dtype.\";\n    TVM_FFI_ICHECK_EQ(send->ndim, shape.size())\n        << \"The shape of the src Tensor does not match the expected shape.\";\n    for (int i = 0; i < send->ndim; ++i) {\n      TVM_FFI_ICHECK_EQ(send->shape[i], shape[i])\n          << \"The shape of the src Tensor does not match the expected shape.\";\n    }\n    SendToWorker(send, /*receiver_id=*/0);\n    return recv;\n  }\n\n  // Only worker 0 and the first worker of the last group take part in the transfer.\n  // All other workers return the input object unchanged.\n  return recv;\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.multi_gpu.DispatchFunctionByGroup\", DispatchFunctionByGroup)\n      .def(\"mlc.multi_gpu.SendFromLastGroupToWorker0\", SendFromLastGroupToWorker0);\n}\n\n}  // namespace multi_gpu\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_SINGLE_GPU_ONLY\n"
  },
  {
    "path": "cpp/multi_gpu/multi_gpu_loader.cc",
    "content": "/*!\n * \\file multi_gpu_loader.cc\n * \\brief Implementation of a multi-GPU loader with loading-time sharding.\n */\n#ifndef MLC_SINGLE_GPU_ONLY\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/optional.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/device_api.h>\n#include <tvm/runtime/disco/builtin.h>\n#include <tvm/runtime/disco/disco_worker.h>\n#include <tvm/runtime/vm/tensor_cache_support.h>\n\n#include <algorithm>\n#include <chrono>\n#include <filesystem>\n#include <fstream>\n#include <functional>\n#include <iomanip>\n#include <numeric>\n#include <sstream>\n#include <string>\n#include <thread>\n#include <unordered_map>\n#include <vector>\n\n#include \"../metadata/model.h\"\n#include \"../support/progress_bar.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace multi_gpu {\n\nusing tvm::Device;\nusing tvm::runtime::vm::TensorCacheMetadata;\nusing namespace tvm::runtime;\nusing tvm::ffi::Array;\nusing tvm::ffi::Function;\nusing tvm::ffi::Optional;\nusing tvm::ffi::TypedFunction;\nusing DurationType = std::chrono::microseconds;\n\n// Adds the wall-clock time between construction and destruction to `*result`.\nclass RangeTimer {\n public:\n  explicit RangeTimer(DurationType* result)\n      : start(std::chrono::high_resolution_clock::now()), result(result) {}\n\n  ~RangeTimer() {\n    std::chrono::time_point<std::chrono::high_resolution_clock> end =\n        std::chrono::high_resolution_clock::now();\n    (*result) += std::chrono::duration_cast<DurationType>(end - start);\n  }\n\n private:\n  std::chrono::time_point<std::chrono::high_resolution_clock> start;\n  DurationType* result;\n};\n\nclass PreprocessorPool {\n public:\n  explicit PreprocessorPool(const ModelMetadata& model_metadata, Module vm_module) {\n    for (const ModelMetadata::Param& param : model_metadata.params) {\n      for (const ModelMetadata::Param::Preproc& preproc : param.preprocs) {\n        const std::string& func_name = preproc.func_name;\n        if (Function f = vm_module.defined()\n    
                         ? vm_module->GetFunction(func_name, true).value_or(Function(nullptr))\n                             : nullptr;\n            f != nullptr) {\n          preproc_funcs[func_name] = f;\n        } else if (const auto f = Function::GetGlobal(func_name); f.has_value()) {\n          preproc_funcs[func_name] = *f;\n        } else {\n          LOG(FATAL) << \"ValueError: Undefined function: \" << func_name;\n        }\n      }\n    }\n  }\n\n  Tensor Apply(Tensor param, const ModelMetadata::Param& param_info) const {\n    for (const ModelMetadata::Param::Preproc& preproc : param_info.preprocs) {\n      const std::string& func_name = preproc.func_name;\n      Tensor param_in = param;\n      param = Tensor::Empty(preproc.out_shape, preproc.out_dtype, param->device);\n      TVM_FFI_ICHECK(preproc_funcs.count(func_name));\n      DLTensor dl_param_in = *param_in.operator->();\n      DLTensor dl_param = *param.operator->();\n      preproc_funcs.at(func_name)(&dl_param_in, &dl_param);\n    }\n    return param;\n  }\n\n private:\n  std::unordered_map<std::string, TypedFunction<void(DLTensor*, DLTensor*)>> preproc_funcs;\n};\n\nstruct ParamInfo {\n  const TensorCacheMetadata::FileRecord* file;\n  const TensorCacheMetadata::FileRecord::ParamRecord* param;\n};\n\nTensor RecvFromGlobalWorker0(Device device, const ModelMetadata::Param& param_info) {\n  Shape shape = param_info.preprocs.empty() ? 
param_info.shape : param_info.preprocs[0].in_shape;\n  Tensor result = Tensor::Empty(shape, param_info.dtype, device);\n  RecvFromWorker0(result);\n  return result;\n}\n\nTensor BroadcastOrShardAndScatter(Tensor param, const ModelMetadata::Param& param_info,\n                                  int num_shards, const PreprocessorPool& preprocs) {\n  bool needs_sharding = !param_info.preprocs.empty();\n  if (!needs_sharding) {\n    BroadcastFromWorker0(param, /*in_group=*/true, param);\n    return param;\n  }\n  Device device = param->device;\n  Shape shape = param_info.preprocs.back().out_shape;\n  DataType dtype = param_info.preprocs.back().out_dtype;\n  TVM_FFI_ICHECK(shape.size() >= 1 && shape[0] == num_shards)\n      << \"ValueError: The first dimension of the output shape must be equal to the \"\n      << \"number of shards, but got: \" << shape << \" and num_shards = \" << num_shards;\n  param = preprocs.Apply(param, param_info);\n  Tensor result = Tensor::Empty(Shape(shape.begin() + 1, shape.end()), dtype, device);\n  ScatterFromWorker0(param, /*in_group=*/true, result);\n  return result;\n}\n\nTensor ReceiveBroadcastedOrSharded(Device device, const ModelMetadata::Param& param_info,\n                                   int num_shards) {\n  bool needs_sharding = !param_info.preprocs.empty();\n  Tensor result;\n  if (needs_sharding) {\n    Shape shape = param_info.preprocs.back().out_shape;\n    DataType dtype = param_info.preprocs.back().out_dtype;\n    result = Tensor::Empty(Shape(shape.begin() + 1, shape.end()), dtype, device);\n    ScatterFromWorker0(std::nullopt, /*in_group=*/true, result);\n  } else {\n    result = Tensor::Empty(param_info.shape, param_info.dtype, device);\n    BroadcastFromWorker0(result, /*in_group=*/true, result);\n  }\n  return result;\n}\n\nstd::string FormatDuration(DurationType duration) {\n  std::ostringstream os;\n  auto float_seconds = std::chrono::duration_cast<std::chrono::duration<float>>(duration).count();\n  os << std::fixed 
<< std::setprecision(3) << float_seconds << \" s\";\n  return os.str();\n}\n\nArray<Optional<Tensor>> LoadMultiGPU(const std::string& model_path, Module vm_module,\n                                     const std::string& model_config_str) {\n  DiscoWorker* worker = DiscoWorker::ThreadLocal();\n  Device device = worker->default_device;\n  int worker_id = worker->worker_id;\n  int group_size = worker->num_workers / worker->num_groups;\n  int num_shards = group_size;\n  int group_id = worker_id / group_size;\n  LOG(INFO) << \"[Worker #\" << worker_id << \"] Loading model to device: \" << device;\n  // Step 0. Initialize metadata and paths\n  TensorCacheMetadata tensor_cache_metadata = TensorCacheMetadata::Load(model_path);\n  tvm::ffi::json::Value model_config = tvm::ffi::json::Parse(model_config_str);\n  ModelMetadata model_metadata =\n      ModelMetadata::FromModule(vm_module, model_config.cast<tvm::ffi::json::Object>());\n  TVM_FFI_ICHECK_EQ(model_metadata.tensor_parallel_shards, num_shards)\n      << \"ValueError: The model is compiled using `--tensor-parallel-shards=\"\n      << model_metadata.tensor_parallel_shards\n      << \"`, but mlc-chat-config.json is configured to use \" << num_shards << \" GPUs. \"\n      << \"Please set \\\"tensor_parallel_shards\\\" in mlc-chat-config.json to \"\n      << model_metadata.tensor_parallel_shards;\n  // Step 1. Extract auxiliary information\n  PreprocessorPool preprocs(model_metadata, vm_module);\n  std::unordered_map<std::string, ModelMetadata::Param> param_name2info;\n  for (const ModelMetadata::Param& param : model_metadata.params) {\n    param_name2info[param.name] = param;\n  }\n  // Step 2. 
Load, preprocess and shard all the parameters\n  std::unordered_map<std::string, Tensor> sharded_params;\n  if (worker_id == 0) {\n    DurationType time_loading(0);\n    DurationType time_preproc(0);\n    ProgressBar progress_bar(model_metadata.params.size());\n    LOG(INFO) << \"Loading parameters...\";\n    for (const TensorCacheMetadata::FileRecord& record : tensor_cache_metadata.records) {\n      Array<Tensor> loaded_params;\n      {\n        RangeTimer _(&time_loading);\n        std::string raw_data_buffer;\n        loaded_params = record.Load(device, model_path, &raw_data_buffer);\n        DeviceAPI::Get(device)->StreamSync(device, nullptr);\n      }\n      // For each parameter in the shard file, preprocess and shard it\n      for (size_t i = 0; i < record.records.size(); ++i, progress_bar.Progress()) {\n        RangeTimer _(&time_preproc);\n        const std::string& param_name = record.records[i].name;\n        const ModelMetadata::Param& param_info = param_name2info.at(param_name);\n        for (int stage_group_id : param_info.pipeline_stages) {\n          if (stage_group_id == 0) {\n            // Broadcast or shard-scatter this parameter to all workers in worker group 0.\n            sharded_params[param_name] =\n                BroadcastOrShardAndScatter(loaded_params[i], param_info, num_shards, preprocs);\n          } else {\n            // Send this parameter to the first worker of worker group `stage_group_id`,\n            // and let that worker process it.\n            SendToWorker(loaded_params[i], /*receiver_id=*/stage_group_id * group_size);\n          }\n        }\n        DeviceAPI::Get(device)->StreamSync(device, nullptr);\n      }\n    }\n    LOG(INFO) << \"Loading done. Time used: Loading \" << FormatDuration(time_loading)\n              << \" Preprocessing \" << FormatDuration(time_preproc) << \".\";\n  } else {\n    for (const TensorCacheMetadata::FileRecord& record : tensor_cache_metadata.records) {\n      for (size_t i = 0; i < record.records.size(); ++i) {\n        const std::string& param_name = record.records[i].name;\n        const ModelMetadata::Param& param_info = param_name2info.at(param_name);\n        if (std::find(param_info.pipeline_stages.begin(), param_info.pipeline_stages.end(),\n                      group_id) == param_info.pipeline_stages.end()) {\n          // This worker group doesn't need to hold a copy of this parameter.\n          continue;\n        }\n\n        if (worker_id % group_size == 0) {\n          // The worker is the first worker of a worker group other than group 0.\n          // Receive the full parameter from the global worker 0.\n          Tensor full_param = RecvFromGlobalWorker0(device, param_info);\n          // Broadcast or shard-scatter this parameter to all workers in its worker group.\n          sharded_params[param_name] =\n              BroadcastOrShardAndScatter(full_param, param_info, num_shards, preprocs);\n        } else {\n          // The worker is not the first worker of its worker group.\n          // Receive from the first worker in its worker group.\n          sharded_params[param_name] = ReceiveBroadcastedOrSharded(device, param_info, num_shards);\n        }\n      }\n    }\n  }\n\n  // Step 3. Reorder the sharded parameters according to the order in model_metadata\n  Array<Optional<Tensor>> shards;\n  shards.reserve(model_metadata.params.size());\n  for (const ModelMetadata::Param& param : model_metadata.params) {\n    auto it = sharded_params.find(param.name);\n    shards.push_back(it == sharded_params.end() ? Optional<Tensor>() : it->second);\n  }\n  return shards;\n}\n\nArray<Optional<Tensor>> LoadMultiGPUPresharded(const std::string& model_path, Module vm_module,\n                                               const std::string& model_config_str) {\n  DiscoWorker* worker = DiscoWorker::ThreadLocal();\n  Device device = worker->default_device;\n  int worker_id = worker->worker_id;\n  int group_size = worker->num_workers / worker->num_groups;\n  int group_id = worker_id / group_size;\n  int local_worker_id = worker_id % group_size;\n  LOG(INFO) << \"[Worker #\" << worker_id << \"] Loading model to device: \" << device;\n  // Step 0. Initialize metadata and paths\n  TensorCacheMetadata tensor_cache_metadata = TensorCacheMetadata::Load(model_path);\n  tvm::ffi::json::Value model_config = tvm::ffi::json::Parse(model_config_str);\n  ModelMetadata model_metadata =\n      ModelMetadata::FromModule(vm_module, model_config.cast<tvm::ffi::json::Object>());\n\n  std::unordered_map<std::string, ParamInfo> param_info_map;\n  for (const TensorCacheMetadata::FileRecord& file_record : tensor_cache_metadata.records) {\n    for (const TensorCacheMetadata::FileRecord::ParamRecord& param_record : file_record.records) {\n      const std::string& param_name = param_record.name;\n      param_info_map[param_name] = ParamInfo{&file_record, &param_record};\n    }\n  }\n\n  Array<Optional<Tensor>> params;\n  // The shard file whose raw bytes are currently held in `current_file_stream`.\n  const TensorCacheMetadata::FileRecord* current_file = nullptr;\n  std::string current_file_stream;\n  params.reserve(model_metadata.params.size());\n  DurationType time_loading(0);\n  for (const ModelMetadata::Param& param : model_metadata.params) {\n    RangeTimer _(&time_loading);\n    if (std::find(param.pipeline_stages.begin(), param.pipeline_stages.end(), group_id) ==\n        param.pipeline_stages.end()) {\n      // This worker group doesn't need to hold a copy of this parameter.\n      params.push_back(Optional<Tensor>());\n      continue;\n    }\n    bool needs_sharding = !param.preprocs.empty();\n    std::string param_name =\n        needs_sharding ? std::string(param.name) + \"_shard-\" + std::to_string(local_worker_id)\n                       : std::string(param.name);\n    auto it = param_info_map.find(param_name);\n    TVM_FFI_ICHECK(it != param_info_map.end())\n        << \"ValueError: Cannot find parameter: \" << param_name;\n    const ParamInfo& param_info = it->second;\n    const TensorCacheMetadata::FileRecord::ParamRecord* param_record = param_info.param;\n    const TensorCacheMetadata::FileRecord* file_record = param_info.file;\n\n    if (file_record != current_file) {\n      current_file = file_record;\n      file_record->Load(device, model_path, &current_file_stream);\n    }\n\n    params.push_back(param_record->Load(device, &current_file_stream));\n  }\n  SyncWorker();\n  if (worker_id == 0) {\n    LOG(INFO) << \"Loading done. Time used: \" << FormatDuration(time_loading) << \".\";\n  }\n  return params;\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.multi_gpu.LoadMultiGPU\", LoadMultiGPU)\n      .def(\"mlc.multi_gpu.LoadMultiGPUPresharded\", LoadMultiGPUPresharded);\n}\n\n}  // namespace multi_gpu\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_SINGLE_GPU_ONLY\n"
  },
  {
    "path": "cpp/serve/config.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/config.cc\n */\n#include \"config.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/runtime/device_api.h>\n\n#include <algorithm>\n#include <cmath>\n#include <limits>\n#include <random>\n\n#include \"../json_ffi/openai_api_protocol.h\"\n#include \"../support/json_parser.h\"\n#include \"../support/utils.h\"\n#include \"data.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  GenerationConfigNode::RegisterReflection();\n  EngineConfigNode::RegisterReflection();\n}\n\nuint64_t TotalDetectGlobalMemory(DLDevice device) {\n  // Get single-card GPU size.\n  tvm::ffi::Any rv;\n  DeviceAPI::Get(device)->GetAttr(device, DeviceAttrKind::kTotalGlobalMemory, &rv);\n  int64_t gpu_size_bytes = rv.cast<int64_t>();\n  // The memory size reported by the OpenCL runtime is smaller than the memory actually\n  // available, so we enforce a lower bound that lets MLC LLM run 7B or 8B models on\n  // Android with OpenCL.\n  if (device.device_type == kDLOpenCL) {\n    int64_t min_size_bytes = 5LL * 1024 * 1024 * 1024;  // Minimum size is 5 GB\n    gpu_size_bytes = std::max(gpu_size_bytes, min_size_bytes);\n  }\n  return gpu_size_bytes;\n}\n\n/****************** ResponseFormat ******************/\n\nResult<ResponseFormat> ResponseFormat::FromJSON(const tvm::ffi::json::Object& config) {\n  using TResult = Result<ResponseFormat>;\n  ResponseFormat res;\n  res.type = json::LookupOrDefault<std::string>(config, \"type\", \"text\");\n\n  std::optional<std::string> schema = json::LookupOptional<std::string>(config, \"schema\");\n  if (schema.has_value()) {\n    res.schema = schema.value();\n  }\n\n  if (res.type != \"text\" && res.type != \"function\" && res.type != \"json_object\") {\n    return TResult::Error(\"Unknown response_format type \" + res.type);\n  }\n\n  return TResult::Ok(res);\n}\n\ntvm::ffi::json::Object ResponseFormat::AsJSON() const {\n  tvm::ffi::json::Object config;\n  
config.Set(\"type\", type);\n  if (schema.has_value()) {\n    config.Set(\"schema\", schema.value());\n  }\n  return config;\n}\n\n/****************** DisaggConfig ******************/\n\nResult<DisaggConfig> DisaggConfig::FromJSON(const tvm::ffi::json::Object& config) {\n  using TResult = Result<DisaggConfig>;\n  DisaggConfig res;\n  std::optional<std::string> kind = json::LookupOptional<std::string>(config, \"kind\");\n  if (kind.has_value()) {\n    if (kind.value() == \"prepare_receive\") {\n      res.kind = DisaggRequestKind::kPrepareReceive;\n    } else if (kind.value() == \"remote_send\") {\n      res.kind = DisaggRequestKind::kRemoteSend;\n    } else if (kind.value() == \"start_generation\") {\n      res.kind = DisaggRequestKind::kStartGeneration;\n    } else {\n      return TResult::Error(\"Unknown disaggregation request kind \" + kind.value());\n    }\n  }\n  std::optional<std::string> kv_append_metadata_encoded =\n      json::LookupOptional<std::string>(config, \"kv_append_metadata\");\n  if (kv_append_metadata_encoded.has_value()) {\n    tvm::ffi::String err;\n    auto parse_result =\n        tvm::ffi::json::Parse(Base64Decode(kv_append_metadata_encoded.value()), &err);\n    if (!err.empty()) {\n      return TResult::Error(\"kv_append_metadata parse error: \" + std::string(err));\n    }\n    if (!parse_result.try_cast<tvm::ffi::json::Array>().has_value()) {\n      return TResult::Error(\"kv_append_metadata is not array of integer.\");\n    }\n    tvm::ffi::json::Array kv_append_metadata_arr = parse_result.cast<tvm::ffi::json::Array>();\n    std::vector<IntTuple> kv_append_metadata;\n    int ptr = 0;\n    while (ptr < static_cast<int>(kv_append_metadata_arr.size())) {\n      if (!kv_append_metadata_arr[ptr].try_cast<int64_t>().has_value()) {\n        return TResult::Error(\"Invalid kv append metadata value in kv_append_metadata array\");\n      }\n      int num_segments = kv_append_metadata_arr[ptr].cast<int64_t>();\n      if (ptr + num_segments * 2 + 1 > 
static_cast<int>(kv_append_metadata_arr.size())) {\n        return TResult::Error(\"Invalid kv append metadata compression in kv_append_metadata\");\n      }\n      std::vector<int64_t> compressed_kv_append_metadata{num_segments};\n      compressed_kv_append_metadata.reserve(num_segments * 2 + 1);\n      for (int i = 1; i <= num_segments * 2; ++i) {\n        if (!kv_append_metadata_arr[ptr + i].try_cast<int64_t>().has_value()) {\n          return TResult::Error(\"Invalid kv append metadata value in kv_append_metadata array\");\n        }\n        compressed_kv_append_metadata.push_back(kv_append_metadata_arr[ptr + i].cast<int64_t>());\n      }\n      kv_append_metadata.push_back(IntTuple(std::move(compressed_kv_append_metadata)));\n      ptr += num_segments * 2 + 1;\n    }\n    res.kv_append_metadata = std::move(kv_append_metadata);\n  }\n  res.kv_window_begin = json::LookupOptional<int64_t>(config, \"kv_window_begin\");\n  res.kv_window_end = json::LookupOptional<int64_t>(config, \"kv_window_end\");\n  res.dst_group_offset = json::LookupOptional<int64_t>(config, \"dst_group_offset\");\n  return TResult::Ok(res);\n}\n\ntvm::ffi::json::Object DisaggConfig::AsJSON() const {\n  tvm::ffi::json::Object config;\n  switch (kind) {\n    case DisaggRequestKind::kPrepareReceive: {\n      config.Set(\"kind\", \"prepare_receive\");\n      break;\n    }\n    case DisaggRequestKind::kRemoteSend: {\n      config.Set(\"kind\", \"remote_send\");\n      break;\n    }\n    case DisaggRequestKind::kStartGeneration: {\n      config.Set(\"kind\", \"start_generation\");\n      break;\n    }\n    default:\n      break;\n  }\n  if (!kv_append_metadata.empty()) {\n    tvm::ffi::json::Array kv_append_metadata_arr;\n    for (const IntTuple& compressed_kv_append_metadata : kv_append_metadata) {\n      for (int64_t value : compressed_kv_append_metadata) {\n        kv_append_metadata_arr.push_back(value);\n      }\n    }\n    config.Set(\"kv_append_metadata\",\n               
Base64Encode(tvm::ffi::json::Stringify(kv_append_metadata_arr)));\n  }\n  if (kv_window_begin.has_value()) {\n    config.Set(\"kv_window_begin\", static_cast<int64_t>(kv_window_begin.value()));\n  }\n  if (kv_window_end.has_value()) {\n    config.Set(\"kv_window_end\", static_cast<int64_t>(kv_window_end.value()));\n  }\n  if (dst_group_offset.has_value()) {\n    config.Set(\"dst_group_offset\", static_cast<int64_t>(dst_group_offset.value()));\n  }\n  return config;\n}\n\n/****************** DebugConfig ******************/\n\nResult<DebugConfig> DebugConfig::FromJSON(const tvm::ffi::json::Object& config) {\n  using TResult = Result<DebugConfig>;\n  DebugConfig res;\n  res.ignore_eos = json::LookupOrDefault<bool>(config, \"ignore_eos\", false);\n  res.pinned_system_prompt = json::LookupOrDefault<bool>(config, \"pinned_system_prompt\", false);\n  std::string special_request = json::LookupOrDefault<std::string>(config, \"special_request\", \"\");\n  if (special_request.length() != 0) {\n    if (special_request == \"query_engine_metrics\") {\n      res.special_request = SpecialRequestKind::kQueryEngineMetrics;\n    } else {\n      return TResult::Error(\"Unknown special request \" + special_request);\n    }\n  }\n  std::string grammar_execution_mode =\n      json::LookupOrDefault<std::string>(config, \"grammar_execution_mode\", \"jump_forward\");\n  if (grammar_execution_mode == \"jump_forward\") {\n    res.grammar_execution_mode = GrammarExecutionMode::kJumpForward;\n  } else if (grammar_execution_mode == \"constraint\") {\n    res.grammar_execution_mode = GrammarExecutionMode::kConstraint;\n  } else {\n    return TResult::Error(\"Unknown grammar execution mode \" + grammar_execution_mode);\n  }\n  if (auto disagg_config_obj =\n          json::LookupOptional<tvm::ffi::json::Object>(config, \"disagg_config\")) {\n    Result<DisaggConfig> disagg_config = DisaggConfig::FromJSON(disagg_config_obj.value());\n    if (disagg_config.IsErr()) {\n      return 
TResult::Error(disagg_config.UnwrapErr());\n    }\n    res.disagg_config = disagg_config.Unwrap();\n  }\n  return TResult::Ok(res);\n}\n\n/**\n * \\return serialized json value of the config.\n */\ntvm::ffi::json::Object DebugConfig::AsJSON() const {\n  tvm::ffi::json::Object config;\n  config.Set(\"ignore_eos\", ignore_eos);\n  config.Set(\"pinned_system_prompt\", pinned_system_prompt);\n  switch (special_request) {\n    case SpecialRequestKind::kQueryEngineMetrics: {\n      config.Set(\"special_request\", \"query_engine_metrics\");\n      break;\n    }\n    case SpecialRequestKind::kNone:\n      break;\n  }\n  switch (grammar_execution_mode) {\n    case GrammarExecutionMode::kJumpForward: {\n      config.Set(\"grammar_execution_mode\", \"jump_forward\");\n      break;\n    }\n    case GrammarExecutionMode::kConstraint: {\n      config.Set(\"grammar_execution_mode\", \"constraint\");\n      break;\n    }\n  }\n  if (disagg_config.kind != DisaggRequestKind::kNone) {\n    config.Set(\"disagg_config\", disagg_config.AsJSON());\n  }\n  return config;\n}\n\n/****************** GenerationConfig ******************/\n\nResult<GenerationConfig> GenerationConfig::Validate(GenerationConfig cfg) {\n  using TResult = Result<GenerationConfig>;\n  if (cfg->n <= 0) {\n    return TResult::Error(\"\\\"n\\\" should be at least 1\");\n  }\n  if (cfg->temperature < 0) {\n    return TResult::Error(\"\\\"temperature\\\" should be non-negative\");\n  }\n  if (cfg->top_p < 0 || cfg->top_p > 1) {\n    return TResult::Error(\"\\\"top_p\\\" should be in range [0, 1]\");\n  }\n  if (std::fabs(cfg->frequency_penalty) > 2.0) {\n    return TResult::Error(\"frequency_penalty must be in [-2, 2]!\");\n  }\n  if (cfg->repetition_penalty <= 0) {\n    return TResult::Error(\"\\\"repetition_penalty\\\" must be positive\");\n  }\n  if (cfg->top_logprobs < 0 || cfg->top_logprobs > 20) {\n    return TResult::Error(\"At most 20 top logprob tokens are supported\");\n  }\n  if (cfg->top_logprobs != 0 && 
!(cfg->logprobs)) {\n    return TResult::Error(\"\\\"logprobs\\\" must be true to support \\\"top_logprobs\\\"\");\n  }\n  for (const auto& item : cfg->logit_bias) {\n    double bias_value = item.second;\n    if (std::fabs(bias_value) > 100.0) {\n      return TResult::Error(\"Logit bias value should be in range [-100, 100].\");\n    }\n  }\n  return TResult::Ok(cfg);\n}\n\nResult<GenerationConfig> GenerationConfig::FromJSON(const tvm::ffi::json::Object& config,\n                                                    const GenerationConfig& default_config) {\n  using TResult = Result<GenerationConfig>;\n  ObjectPtr<GenerationConfigNode> n = tvm::ffi::make_object<GenerationConfigNode>();\n  n->n = json::LookupOrDefault<int64_t>(config, \"n\", default_config->n);\n  n->temperature =\n      json::LookupOrDefault<double>(config, \"temperature\", default_config->temperature);\n  n->top_p = json::LookupOrDefault<double>(config, \"top_p\", default_config->top_p);\n  n->frequency_penalty =\n      json::LookupOrDefault<double>(config, \"frequency_penalty\", default_config->frequency_penalty);\n  n->presence_penalty =\n      json::LookupOrDefault<double>(config, \"presence_penalty\", default_config->presence_penalty);\n  n->repetition_penalty = json::LookupOrDefault<double>(config, \"repetition_penalty\",\n                                                        default_config->repetition_penalty);\n  n->logprobs = json::LookupOrDefault<bool>(config, \"logprobs\", default_config->logprobs);\n  n->top_logprobs =\n      json::LookupOrDefault<int64_t>(config, \"top_logprobs\", default_config->top_logprobs);\n\n  std::optional<tvm::ffi::json::Object> logit_bias_obj =\n      json::LookupOptional<tvm::ffi::json::Object>(config, \"logit_bias\");\n  if (logit_bias_obj.has_value()) {\n    std::vector<std::pair<int, float>> logit_bias;\n    logit_bias.reserve(logit_bias_obj.value().size());\n    for (auto [k, v] : logit_bias_obj.value()) {\n      std::string 
token_id_str(k.cast<tvm::ffi::String>());\n      if (!v.try_cast<double>().has_value()) {\n        return TResult::Error(\"Invalid logit bias value in logit_bias\");\n      }\n      double bias_value = v.cast<double>();\n      logit_bias.emplace_back(std::stoi(token_id_str), bias_value);\n    }\n    n->logit_bias = std::move(logit_bias);\n  } else {\n    n->logit_bias = default_config->logit_bias;\n  }\n\n  n->seed = json::LookupOrDefault<int64_t>(config, \"seed\", std::random_device{}());\n  // A \"max_tokens\" of -1 means generation does not stop until it exceeds the\n  // model's capability or hits a stop criterion.\n  n->max_tokens = json::LookupOrDefault<int64_t>(config, \"max_tokens\", -1);\n\n  std::optional<tvm::ffi::json::Array> stop_strs_arr =\n      json::LookupOptional<tvm::ffi::json::Array>(config, \"stop_strs\");\n  if (stop_strs_arr.has_value()) {\n    Array<String> stop_strs;\n    stop_strs.reserve(stop_strs_arr.value().size());\n    for (const auto& v : stop_strs_arr.value()) {\n      if (!v.try_cast<std::string>().has_value()) {\n        return TResult::Error(\"Invalid stop string in stop_strs\");\n      }\n      stop_strs.push_back(v.cast<std::string>());\n    }\n    n->stop_strs = std::move(stop_strs);\n  } else {\n    n->stop_strs = default_config->stop_strs;\n  }\n  std::optional<tvm::ffi::json::Array> stop_token_ids_arr =\n      json::LookupOptional<tvm::ffi::json::Array>(config, \"stop_token_ids\");\n  if (stop_token_ids_arr.has_value()) {\n    std::vector<int> stop_token_ids;\n    stop_token_ids.reserve(stop_token_ids_arr.value().size());\n    for (const auto& v : stop_token_ids_arr.value()) {\n      if (!v.try_cast<int64_t>().has_value()) {\n        return TResult::Error(\"Invalid stop token in stop_token_ids\");\n      }\n      stop_token_ids.push_back(v.cast<int64_t>());\n    }\n    n->stop_token_ids = std::move(stop_token_ids);\n  } else {\n    n->stop_token_ids = default_config->stop_token_ids;\n  }\n\n  std::optional<tvm::ffi::json::Object> response_format_obj =\n      json::LookupOptional<tvm::ffi::json::Object>(config, \"response_format\");\n  if (response_format_obj.has_value()) {\n    Result<ResponseFormat> response_format_res =\n        ResponseFormat::FromJSON(response_format_obj.value());\n    if (response_format_res.IsErr()) {\n      return TResult::Error(response_format_res.UnwrapErr());\n    }\n    n->response_format = response_format_res.Unwrap();\n  } else {\n    n->response_format = default_config->response_format;\n  }\n  // \"debug_config\" is for internal use only and is not part of the OpenAI API spec.\n  std::optional<tvm::ffi::json::Object> debug_config_obj =\n      json::LookupOptional<tvm::ffi::json::Object>(config, \"debug_config\");\n\n  if (debug_config_obj.has_value()) {\n    Result<DebugConfig> debug_config_res = DebugConfig::FromJSON(debug_config_obj.value());\n    if (debug_config_res.IsErr()) {\n      return TResult::Error(debug_config_res.UnwrapErr());\n    }\n    n->debug_config = debug_config_res.Unwrap();\n  }\n  return Validate(GenerationConfig(n));\n}\n\nGenerationConfig GenerationConfig::GetDefaultFromModelConfig(\n    const tvm::ffi::json::Object& model_config_json) {\n  ObjectPtr<GenerationConfigNode> n = tvm::ffi::make_object<GenerationConfigNode>();\n  n->max_tokens = -1;\n  n->temperature = json::LookupOrDefault<double>(model_config_json, \"temperature\", n->temperature);\n  n->top_p = json::LookupOrDefault<double>(model_config_json, \"top_p\", n->top_p);\n  n->frequency_penalty =\n      json::LookupOrDefault<double>(model_config_json, \"frequency_penalty\", n->frequency_penalty);\n  n->presence_penalty =\n      json::LookupOrDefault<double>(model_config_json, \"presence_penalty\", n->presence_penalty);\n  return GenerationConfig(n);\n}\n\ntvm::ffi::json::Object GenerationConfigNode::AsJSON() const {\n  tvm::ffi::json::Object config;\n  config.Set(\"n\", static_cast<int64_t>(this->n));\n  config.Set(\"temperature\", this->temperature);\n  config.Set(\"top_p\", this->top_p);\n  config.Set(\"frequency_penalty\", this->frequency_penalty);\n  
config.Set(\"presence_penalty\", this->presence_penalty);\n  config.Set(\"repetition_penalty\", this->repetition_penalty);\n  config.Set(\"logprobs\", this->logprobs);\n  config.Set(\"top_logprobs\", static_cast<int64_t>(this->top_logprobs));\n  config.Set(\"max_tokens\", static_cast<int64_t>(this->max_tokens));\n  config.Set(\"seed\", static_cast<int64_t>(this->seed));\n\n  tvm::ffi::json::Object logit_bias_obj;\n  for (auto [token_id, bias] : logit_bias) {\n    logit_bias_obj.Set(std::to_string(token_id), static_cast<double>(bias));\n  }\n  config.Set(\"logit_bias\", logit_bias_obj);\n\n  tvm::ffi::json::Array stop_strs_arr;\n  for (String stop_str : this->stop_strs) {\n    stop_strs_arr.push_back(stop_str);\n  }\n  config.Set(\"stop_strs\", stop_strs_arr);\n\n  tvm::ffi::json::Array stop_token_ids_arr;\n  for (int stop_token_id : this->stop_token_ids) {\n    stop_token_ids_arr.push_back(static_cast<int64_t>(stop_token_id));\n  }\n  config.Set(\"stop_token_ids\", stop_token_ids_arr);\n\n  tvm::ffi::json::Object response_format;\n  response_format.Set(\"type\", this->response_format.type);\n  if (this->response_format.schema) {\n    response_format.Set(\"schema\", this->response_format.schema.value());\n  } else {\n    response_format.Set(\"schema\", tvm::Any(nullptr));\n  }\n  config.Set(\"response_format\", response_format);\n  config.Set(\"debug_config\", debug_config.AsJSON());\n  return config;\n}\n\n/****************** EngineConfig ******************/\n\nEngineConfig EngineConfig::FromJSONAndInferredConfig(\n    const tvm::ffi::json::Object& json, const InferrableEngineConfig& inferred_config) {\n  TVM_FFI_ICHECK(inferred_config.max_num_sequence.has_value());\n  TVM_FFI_ICHECK(inferred_config.max_total_sequence_length.has_value());\n  TVM_FFI_ICHECK(inferred_config.prefill_chunk_size.has_value());\n  TVM_FFI_ICHECK(inferred_config.max_history_size.has_value());\n  ObjectPtr<EngineConfigNode> n = tvm::ffi::make_object<EngineConfigNode>();\n\n  // - Get models 
and model libs.\n  n->model = json::Lookup<std::string>(json, \"model\");\n  n->model_lib = json::Lookup<std::string>(json, \"model_lib\");\n  std::vector<String> additional_models;\n  std::vector<String> additional_model_libs;\n  tvm::ffi::json::Array additional_models_arr = json::LookupOrDefault<tvm::ffi::json::Array>(\n      json, \"additional_models\", tvm::ffi::json::Array());\n  int num_additional_models = additional_models_arr.size();\n  additional_models.reserve(num_additional_models);\n  additional_model_libs.reserve(num_additional_models);\n  for (int i = 0; i < num_additional_models; ++i) {\n    tvm::ffi::json::Array additional_model_pair =\n        json::Lookup<tvm::ffi::json::Array>(additional_models_arr, i);\n    additional_models.push_back(json::Lookup<std::string>(additional_model_pair, 0));\n    additional_model_libs.push_back(json::Lookup<std::string>(additional_model_pair, 1));\n  }\n  n->additional_models = additional_models;\n  n->additional_model_libs = additional_model_libs;\n  n->mode = EngineModeFromString(json::Lookup<std::string>(json, \"mode\"));\n\n  // - Other fields with default value.\n  n->gpu_memory_utilization = static_cast<float>(\n      json::LookupOrDefault<double>(json, \"gpu_memory_utilization\", n->gpu_memory_utilization));\n  n->kv_cache_page_size = static_cast<int>(\n      json::LookupOrDefault<int64_t>(json, \"kv_cache_page_size\", n->kv_cache_page_size));\n  n->speculative_mode = SpeculativeModeFromString(json::LookupOrDefault<std::string>(\n      json, \"speculative_mode\", SpeculativeModeToString(n->speculative_mode)));\n  n->spec_draft_length = static_cast<int>(\n      json::LookupOrDefault<int64_t>(json, \"spec_draft_length\", n->spec_draft_length));\n  n->spec_tree_width =\n      static_cast<int>(json::LookupOrDefault<int64_t>(json, \"spec_tree_width\", n->spec_tree_width));\n  n->prefill_mode = PrefillModeFromString(json::LookupOrDefault<std::string>(\n      json, \"prefill_mode\", 
PrefillModeToString(n->prefill_mode)));\n  n->verbose = json::LookupOrDefault<bool>(json, \"verbose\", n->verbose);\n\n  // - Fields from the inferred engine config.\n  n->max_num_sequence = inferred_config.max_num_sequence.value();\n  n->max_total_sequence_length = inferred_config.max_total_sequence_length.value();\n  if (inferred_config.max_single_sequence_length.has_value()) {\n    n->max_single_sequence_length = inferred_config.max_single_sequence_length.value();\n  }\n  n->prefill_chunk_size = inferred_config.prefill_chunk_size.value();\n  n->max_history_size = inferred_config.max_history_size.value();\n\n  n->prefix_cache_mode = PrefixCacheModeFromString(json::LookupOrDefault<std::string>(\n      json, \"prefix_cache_mode\", PrefixCacheModeToString(n->prefix_cache_mode)));\n  n->prefix_cache_max_num_recycling_seqs = static_cast<int>(json::LookupOrDefault<int64_t>(\n      json, \"prefix_cache_max_num_recycling_seqs\", n->max_num_sequence));\n  return EngineConfig(n);\n}\n\nResult<std::vector<std::pair<std::string, std::string>>>\nEngineConfig::GetModelsAndModelLibsFromJSONString(const std::string& json_str) {\n  using TResult = Result<std::vector<std::pair<std::string, std::string>>>;\n  tvm::ffi::String err;\n  auto config_json = tvm::ffi::json::Parse(json_str, &err);\n  if (!err.empty()) {\n    return TResult::Error(err);\n  }\n\n  // Get the models and model libs from JSON.\n  tvm::ffi::json::Object config = config_json.cast<tvm::ffi::json::Object>();\n  String model = json::Lookup<std::string>(config, \"model\");\n  String model_lib = json::Lookup<std::string>(config, \"model_lib\");\n  tvm::ffi::json::Array additional_models_arr = json::LookupOrDefault<tvm::ffi::json::Array>(\n      config, \"additional_models\", tvm::ffi::json::Array());\n\n  int num_additional_models = additional_models_arr.size();\n  std::vector<std::pair<std::string, std::string>> models_and_model_libs;\n  models_and_model_libs.reserve(num_additional_models + 1);\n  
models_and_model_libs.emplace_back(model, model_lib);\n  for (int i = 0; i < num_additional_models; ++i) {\n    tvm::ffi::json::Array additional_model_pair =\n        json::Lookup<tvm::ffi::json::Array>(additional_models_arr, i);\n    models_and_model_libs.emplace_back(json::Lookup<std::string>(additional_model_pair, 0),\n                                       json::Lookup<std::string>(additional_model_pair, 1));\n  }\n  return TResult::Ok(models_and_model_libs);\n}\n\nString EngineConfigNode::AsJSONString() const {\n  tvm::ffi::json::Object config;\n\n  // - Models and model libs\n  config.Set(\"model\", this->model);\n  config.Set(\"model_lib\", this->model_lib);\n  tvm::ffi::json::Array additional_models_arr;\n  additional_models_arr.reserve(this->additional_models.size());\n  for (int i = 0; i < static_cast<int>(this->additional_models.size()); ++i) {\n    tvm::ffi::json::Array pair;\n    pair.push_back(this->additional_models[i]);\n    pair.push_back(this->additional_model_libs[i]);\n    additional_models_arr.push_back(pair);\n  }\n  config.Set(\"additional_models\", additional_models_arr);\n\n  // - Other fields\n  config.Set(\"mode\", EngineModeToString(this->mode));\n  config.Set(\"gpu_memory_utilization\", static_cast<double>(this->gpu_memory_utilization));\n  config.Set(\"kv_cache_page_size\", static_cast<int64_t>(this->kv_cache_page_size));\n  config.Set(\"max_num_sequence\", static_cast<int64_t>(this->max_num_sequence));\n  config.Set(\"max_total_sequence_length\", static_cast<int64_t>(this->max_total_sequence_length));\n  config.Set(\"max_single_sequence_length\", static_cast<int64_t>(this->max_single_sequence_length));\n  config.Set(\"prefill_chunk_size\", static_cast<int64_t>(this->prefill_chunk_size));\n  config.Set(\"max_history_size\", static_cast<int64_t>(this->max_history_size));\n  config.Set(\"prefix_cache_mode\", PrefixCacheModeToString(this->prefix_cache_mode));\n  config.Set(\"prefix_cache_max_num_recycling_seqs\",\n             
static_cast<int64_t>(this->prefix_cache_max_num_recycling_seqs));\n  config.Set(\"speculative_mode\", SpeculativeModeToString(this->speculative_mode));\n  config.Set(\"spec_draft_length\", static_cast<int64_t>(this->spec_draft_length));\n  config.Set(\"prefill_mode\", PrefillModeToString(this->prefill_mode));\n  config.Set(\"verbose\", static_cast<bool>(this->verbose));\n\n  return tvm::ffi::json::Stringify(config, 2);\n}\n\n/****************** InferrableEngineConfig ******************/\n\n/*! \\brief The class for config limitation from models. */\nstruct ModelConfigLimits {\n  int64_t model_compile_time_max_single_sequence_length;\n  int64_t model_runtime_max_single_sequence_length;\n  int64_t model_compile_time_max_prefill_chunk_size;\n  int64_t model_runtime_max_prefill_chunk_size;\n  int64_t model_max_sliding_window_size;\n  int64_t model_max_batch_size;\n};\n\n/*! \\brief Convert the bytes to megabytes, keeping 3 decimals. */\ninline std::string BytesToMegabytesString(double bytes) {\n  std::ostringstream os;\n  os << std::setprecision(3) << std::fixed << (bytes / 1024 / 1024);\n  return os.str();\n}\n\n/*!\n * \\brief Get the upper bound of single sequence length, prefill size and batch size\n * from model config.\n */\nResult<ModelConfigLimits> GetModelConfigLimits(\n    const std::vector<tvm::ffi::json::Object>& model_configs,\n    const std::vector<ModelMetadata>& model_metadata) {\n  TVM_FFI_ICHECK_EQ(model_configs.size(), model_metadata.size());\n  int64_t model_compile_time_max_single_sequence_length = std::numeric_limits<int64_t>::max();\n  int64_t model_runtime_max_single_sequence_length = std::numeric_limits<int64_t>::max();\n  int64_t model_compile_time_max_prefill_chunk_size = std::numeric_limits<int64_t>::max();\n  int64_t model_runtime_max_prefill_chunk_size = std::numeric_limits<int64_t>::max();\n  int64_t model_max_batch_size = std::numeric_limits<int64_t>::max();\n  int64_t model_max_sliding_window_size = 
std::numeric_limits<int64_t>::max();\n  for (int i = 0; i < static_cast<int>(model_configs.size()); ++i) {\n    // - The maximum single sequence length is the minimum context window size among all models.\n    int64_t runtime_context_window_size =\n        json::LookupOptional<int64_t>(model_configs[i], \"context_window_size\").value_or(-1);\n    int64_t compile_time_context_window_size = model_metadata[i].context_window_size;\n\n    // limit runtime setting by compile time setting\n    if (compile_time_context_window_size != -1) {\n      if (runtime_context_window_size == -1 ||\n          runtime_context_window_size > compile_time_context_window_size) {\n        runtime_context_window_size = compile_time_context_window_size;\n      }\n    }\n\n    if (compile_time_context_window_size != -1) {\n      model_compile_time_max_single_sequence_length =\n          std::min(model_compile_time_max_single_sequence_length, compile_time_context_window_size);\n    }\n    if (runtime_context_window_size != -1) {\n      model_runtime_max_single_sequence_length =\n          std::min(model_runtime_max_single_sequence_length, runtime_context_window_size);\n    }\n    // - The maximum prefill chunk size is the minimum prefill chunk size among all models.\n    int64_t runtime_prefill_chunk_size =\n        json::Lookup<int64_t>(model_configs[i], \"prefill_chunk_size\");\n    int64_t compile_time_prefill_chunk_size = model_metadata[i].prefill_chunk_size;\n\n    // limit runtime setting by compile time setting\n    if (compile_time_prefill_chunk_size != -1) {\n      if (runtime_prefill_chunk_size == -1 ||\n          runtime_prefill_chunk_size > compile_time_prefill_chunk_size) {\n        runtime_prefill_chunk_size = compile_time_prefill_chunk_size;\n      }\n    }\n\n    if (compile_time_prefill_chunk_size != -1) {\n      model_compile_time_max_prefill_chunk_size =\n          std::min(model_compile_time_max_prefill_chunk_size, compile_time_prefill_chunk_size);\n    }\n    if 
(runtime_prefill_chunk_size != -1) {\n      model_runtime_max_prefill_chunk_size =\n          std::min(model_runtime_max_prefill_chunk_size, runtime_prefill_chunk_size);\n    }\n    // - The maximum batch size is the minimum max batch size among all models.\n    model_max_batch_size = std::min(model_max_batch_size, model_metadata[i].max_batch_size);\n    // - The maximum sliding window size is the minimum among all models.\n    int64_t runtime_sliding_window_size =\n        json::LookupOptional<int64_t>(model_configs[i], \"sliding_window_size\").value_or(-1);\n    if (runtime_sliding_window_size != -1) {\n      model_max_sliding_window_size =\n          std::min(model_max_sliding_window_size, runtime_sliding_window_size);\n    }\n  }\n  TVM_FFI_ICHECK_NE(model_compile_time_max_prefill_chunk_size, std::numeric_limits<int64_t>::max());\n  TVM_FFI_ICHECK_NE(model_runtime_max_prefill_chunk_size, std::numeric_limits<int64_t>::max());\n  TVM_FFI_ICHECK_NE(model_max_batch_size, std::numeric_limits<int64_t>::max());\n  TVM_FFI_ICHECK_GT(model_compile_time_max_prefill_chunk_size, 0);\n  TVM_FFI_ICHECK_GT(model_runtime_max_prefill_chunk_size, 0);\n  TVM_FFI_ICHECK_GT(model_max_batch_size, 0);\n  return Result<ModelConfigLimits>::Ok(\n      {model_compile_time_max_single_sequence_length, model_runtime_max_single_sequence_length,\n       model_compile_time_max_prefill_chunk_size, model_runtime_max_prefill_chunk_size,\n       model_max_sliding_window_size, model_max_batch_size});\n}\n\n/*! \\brief The class for memory usage estimation result. 
*/\nstruct MemUsageEstimationResult {\n  double total_memory_bytes;\n  double kv_cache_memory_bytes;\n  double temp_memory_bytes;\n  InferrableEngineConfig inferred_config;\n};\n\nResult<MemUsageEstimationResult> EstimateMemoryUsageOnMode(\n    EngineMode mode, Device device, double gpu_memory_utilization, int64_t params_bytes,\n    int64_t temp_buffer_bytes,\n    const std::vector<tvm::ffi::json::Object>& model_configs,  //\n    const std::vector<ModelMetadata>& model_metadata,          //\n    ModelConfigLimits model_config_limits,                     //\n    InferrableEngineConfig init_config, bool verbose) {\n  std::ostringstream os;\n  InferrableEngineConfig inferred_config = init_config;\n  // - 1. max_num_sequence\n  if (!init_config.max_num_sequence.has_value()) {\n    if (mode == EngineMode::kLocal) {\n      inferred_config.max_num_sequence =\n          std::min(static_cast<int64_t>(4), model_config_limits.model_max_batch_size);\n    } else if (mode == EngineMode::kInteractive) {\n      inferred_config.max_num_sequence = 1;\n    } else {\n      inferred_config.max_num_sequence = model_config_limits.model_max_batch_size;\n    }\n    os << \"max batch size will be set to \" << inferred_config.max_num_sequence.value() << \", \";\n  } else {\n    os << \"max batch size \" << inferred_config.max_num_sequence.value()\n       << \" is specified by user, \";\n  }\n  int64_t max_num_sequence = inferred_config.max_num_sequence.value();\n  // - 2. max_single_sequence_length\n  if (!init_config.max_single_sequence_length.has_value()) {\n    inferred_config.max_single_sequence_length =\n        model_config_limits.model_runtime_max_single_sequence_length;\n  } else {\n    inferred_config.max_single_sequence_length =\n        std::min(inferred_config.max_single_sequence_length.value(),\n                 model_config_limits.model_compile_time_max_single_sequence_length);\n  }\n  // - 3. 
infer the maximum total sequence length that can fit GPU memory.\n  double kv_bytes_per_token = 0;\n  double kv_aux_workspace_bytes = 0;\n  double model_workspace_bytes = 0;\n  double logit_processor_workspace_bytes = 0;\n  TVM_FFI_ICHECK_EQ(model_configs.size(), model_metadata.size());\n  int num_models = model_configs.size();\n  for (int i = 0; i < num_models; ++i) {\n    // - Read the vocab size and compile-time prefill chunk size (which affects memory allocation).\n    tvm::ffi::json::Object compile_time_model_config =\n        json::Lookup<tvm::ffi::json::Object>(model_configs[i], \"model_config\");\n    int64_t vocab_size = json::Lookup<int64_t>(compile_time_model_config, \"vocab_size\");\n    int64_t prefill_chunk_size =\n        json::Lookup<int64_t>(compile_time_model_config, \"prefill_chunk_size\");\n    // - Calculate KV cache memory usage.\n    int64_t num_layers = model_metadata[i].kv_cache_metadata.num_hidden_layers;\n    int64_t head_dim = model_metadata[i].kv_cache_metadata.head_dim;\n    int64_t num_qo_heads = model_metadata[i].kv_cache_metadata.num_attention_heads;\n    int64_t num_kv_heads = model_metadata[i].kv_cache_metadata.num_key_value_heads;\n    int64_t hidden_size = head_dim * num_qo_heads;\n    kv_bytes_per_token +=\n        head_dim * num_kv_heads * (num_layers / model_metadata[i].pipeline_parallel_stages) * 4 +\n        1.25;\n    kv_aux_workspace_bytes +=\n        (max_num_sequence + 1) * 88 + prefill_chunk_size * (num_qo_heads + 1) * 8 +\n        prefill_chunk_size * head_dim * (num_qo_heads + num_kv_heads) * 4 + 48 * 1024 * 1024;\n    model_workspace_bytes += prefill_chunk_size * 4 + max_num_sequence * 4 +\n                             (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2;\n    logit_processor_workspace_bytes +=\n        max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125;\n  }\n  int64_t gpu_size_bytes = TotalDetectGlobalMemory(device);\n  // Compute the maximum total sequence length under the 
GPU memory budget.\n  int64_t model_max_total_sequence_length =\n      static_cast<int>((gpu_size_bytes * gpu_memory_utilization  //\n                        - params_bytes                           //\n                        - temp_buffer_bytes                      //\n                        - kv_aux_workspace_bytes                 //\n                        - model_workspace_bytes                  //\n                        - logit_processor_workspace_bytes) /\n                       kv_bytes_per_token);\n  if (model_max_total_sequence_length <= 0) {\n    if (verbose) {\n      LOG(INFO) << \"temp_buffer = \" << BytesToMegabytesString(temp_buffer_bytes);\n      LOG(INFO) << \"kv_aux workspace = \" << BytesToMegabytesString(kv_aux_workspace_bytes);\n      LOG(INFO) << \"model workspace = \" << BytesToMegabytesString(model_workspace_bytes);\n      LOG(INFO) << \"logit processor workspace = \"\n                << BytesToMegabytesString(logit_processor_workspace_bytes);\n    }\n    return Result<MemUsageEstimationResult>::Error(\n        \"Insufficient GPU memory error: \"\n        \"The available single GPU memory is \" +\n        BytesToMegabytesString(gpu_size_bytes * gpu_memory_utilization) +\n        \" MB, \"\n        \"which is less than the sum of model weight size (\" +\n        BytesToMegabytesString(params_bytes) + \" MB) and temporary buffer size (\" +\n        BytesToMegabytesString(temp_buffer_bytes + kv_aux_workspace_bytes + model_workspace_bytes +\n                               logit_processor_workspace_bytes) +\n        \" MB).\\n\"\n        \"1. You can set a larger \\\"gpu_memory_utilization\\\" value.\\n\"\n        \"2. If the model weight size is too large, please enable tensor parallelism by passing \"\n        \"`--tensor-parallel-shards $NGPU` to `mlc_llm gen_config` or use quantization.\\n\"\n        \"3. 
If the temporary buffer size is too large, please use a smaller `--prefill-chunk-size` \"\n        \"in `mlc_llm gen_config`.\");\n  }\n  if (device.device_type == DLDeviceType::kDLMetal) {\n    // NOTE: Metal runtime has severe performance issues with large buffers.\n    // To work around the issue, we limit the KV cache capacity to 32768.\n    model_max_total_sequence_length =\n        std::min(model_max_total_sequence_length, static_cast<int64_t>(32768));\n  }\n  // Compute the total memory usage except the KV cache part.\n  double total_mem_usage_except_kv_cache =\n      (params_bytes + temp_buffer_bytes + kv_aux_workspace_bytes + model_workspace_bytes +\n       logit_processor_workspace_bytes);\n\n  // - 4. max_total_sequence_length\n  if (!init_config.max_total_sequence_length.has_value()) {\n    if (mode == EngineMode::kLocal) {\n      inferred_config.max_total_sequence_length = std::min(\n          {model_max_total_sequence_length, inferred_config.max_single_sequence_length.value(),\n           static_cast<int64_t>(8192)});\n    } else if (mode == EngineMode::kInteractive) {\n      inferred_config.max_total_sequence_length = std::min(\n          {model_max_total_sequence_length, inferred_config.max_single_sequence_length.value()});\n    } else {\n      inferred_config.max_total_sequence_length =\n          inferred_config.max_single_sequence_length.value() == std::numeric_limits<int64_t>::max()\n              ? model_max_total_sequence_length\n              : std::min(model_max_total_sequence_length,\n                         max_num_sequence * inferred_config.max_single_sequence_length.value());\n    }\n    os << \"max KV cache token capacity will be set to \"\n       << inferred_config.max_total_sequence_length.value() << \", \";\n  } else {\n    os << \"max KV cache token capacity \" << inferred_config.max_total_sequence_length.value()\n       << \" is specified by user, \";\n  }\n  // - 5. 
prefill_chunk_size\n  if (!init_config.prefill_chunk_size.has_value()) {\n    if (mode == EngineMode::kLocal || mode == EngineMode::kInteractive) {\n      inferred_config.prefill_chunk_size =\n          std::min({model_config_limits.model_runtime_max_prefill_chunk_size,\n                    inferred_config.max_total_sequence_length.value(),\n                    inferred_config.max_single_sequence_length.value()});\n    } else {\n      inferred_config.prefill_chunk_size = model_config_limits.model_runtime_max_prefill_chunk_size;\n    }\n    os << \"prefill chunk size will be set to \" << inferred_config.prefill_chunk_size.value()\n       << \". \";\n  } else {\n    os << \"prefill chunk size \" << inferred_config.prefill_chunk_size.value()\n       << \" is specified by user. \";\n  }\n\n  // - Print logging message\n  if (verbose) {\n    LOG(INFO) << \"Under mode \\\"\" << EngineModeToString(mode) << \"\\\", \" << os.str();\n  }\n\n  return Result<MemUsageEstimationResult>::Ok(\n      {total_mem_usage_except_kv_cache +\n           inferred_config.max_total_sequence_length.value() * kv_bytes_per_token,\n       kv_bytes_per_token * inferred_config.max_total_sequence_length.value() +\n           kv_aux_workspace_bytes,\n       model_workspace_bytes + logit_processor_workspace_bytes + temp_buffer_bytes,\n       inferred_config});\n}\n\nResult<InferrableEngineConfig> InferrableEngineConfig::InferForKVCache(\n    EngineMode mode, Device device, double gpu_memory_utilization,\n    const std::vector<tvm::ffi::json::Object>& model_configs,\n    const std::vector<ModelMetadata>& model_metadata, InferrableEngineConfig init_config,\n    bool verbose) {\n  // - Check if max_history_size is not set.\n  if (init_config.max_history_size.has_value() && init_config.max_history_size.value() != 0) {\n    return Result<InferrableEngineConfig>::Error(\n        \"KV cache does not support max_history_size, while it is set to \" +\n        
std::to_string(init_config.max_history_size.value()) + \" in the input EngineConfig\");\n  }\n  // - Get the upper bound of single sequence length, prefill size and batch size\n  // from model config.\n  Result<ModelConfigLimits> model_config_limits_res =\n      GetModelConfigLimits(model_configs, model_metadata);\n\n  if (model_config_limits_res.IsErr()) {\n    return Result<InferrableEngineConfig>::Error(model_config_limits_res.UnwrapErr());\n  }\n  ModelConfigLimits model_config_limits = model_config_limits_res.Unwrap();\n  // - Get total model parameter size and temporary in-function buffer\n  // size in bytes on single GPU.\n  int64_t params_bytes = 0;\n  int64_t temp_buffer_bytes = 0;\n  for (const ModelMetadata& metadata : model_metadata) {\n    for (const ModelMetadata::Param& param : metadata.params) {\n      int64_t param_size = param.dtype.bytes();\n      for (int64_t v : param.shape) {\n        TVM_FFI_ICHECK_GE(v, 0);\n        param_size *= v;\n      }\n      params_bytes += param_size;\n    }\n    params_bytes /= metadata.pipeline_parallel_stages;\n    for (const auto& [func_name, temp_buffer_size] : metadata.memory_usage) {\n      temp_buffer_bytes = std::max(temp_buffer_bytes, temp_buffer_size);\n    }\n  }\n  // Magnify the temp buffer by a factor of 2 for safety.\n  temp_buffer_bytes *= 2;\n\n  // - Infer the engine config and estimate memory usage for each mode.\n  Result<MemUsageEstimationResult> local_mode_estimation_result = EstimateMemoryUsageOnMode(\n      EngineMode::kLocal, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes,\n      model_configs, model_metadata, model_config_limits, init_config, verbose);\n  Result<MemUsageEstimationResult> interactive_mode_estimation_result = EstimateMemoryUsageOnMode(\n      EngineMode::kInteractive, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes,\n      model_configs, model_metadata, model_config_limits, init_config, verbose);\n  Result<MemUsageEstimationResult> 
server_mode_estimation_result = EstimateMemoryUsageOnMode(\n      EngineMode::kServer, device, gpu_memory_utilization, params_bytes, temp_buffer_bytes,\n      model_configs, model_metadata, model_config_limits, init_config, verbose);\n  // - Pick the estimation result according to the mode.\n  std::string mode_name;\n  Result<MemUsageEstimationResult> final_estimation_result;\n  if (mode == EngineMode::kLocal) {\n    final_estimation_result = std::move(local_mode_estimation_result);\n  } else if (mode == EngineMode::kInteractive) {\n    final_estimation_result = std::move(interactive_mode_estimation_result);\n  } else {\n    final_estimation_result = std::move(server_mode_estimation_result);\n  }\n  if (final_estimation_result.IsErr()) {\n    return Result<InferrableEngineConfig>::Error(final_estimation_result.UnwrapErr());\n  }\n  // - Print log message.\n  MemUsageEstimationResult final_estimation = final_estimation_result.Unwrap();\n  InferrableEngineConfig inferred_config = std::move(final_estimation.inferred_config);\n\n  if (verbose) {\n    LOG(INFO) << \"The actual engine mode is \\\"\" << EngineModeToString(mode)\n              << \"\\\". So max batch size is \" << inferred_config.max_num_sequence.value()\n              << \", max KV cache token capacity is \"\n              << inferred_config.max_total_sequence_length.value() << \", prefill chunk size is \"\n              << inferred_config.prefill_chunk_size.value() << \".\";\n    LOG(INFO) << \"Estimated total single GPU memory usage: \"\n              << BytesToMegabytesString(final_estimation.total_memory_bytes)\n              << \" MB (Parameters: \" << BytesToMegabytesString(params_bytes)\n              << \" MB. KVCache: \" << BytesToMegabytesString(final_estimation.kv_cache_memory_bytes)\n              << \" MB. Temporary buffer: \"\n              << BytesToMegabytesString(final_estimation.temp_memory_bytes)\n              << \" MB). 
The actual usage might be slightly larger than the estimated number.\";\n  }\n\n  inferred_config.max_history_size = 0;\n  return Result<InferrableEngineConfig>::Ok(inferred_config);\n}\n\nResult<InferrableEngineConfig> InferrableEngineConfig::InferForRNNState(\n    EngineMode mode, Device device, double gpu_memory_utilization,\n    const std::vector<tvm::ffi::json::Object>& model_configs,\n    const std::vector<ModelMetadata>& model_metadata, InferrableEngineConfig init_config,\n    bool verbose) {\n  // - Check max_single_sequence_length is not set.\n  if (init_config.max_single_sequence_length.has_value()) {\n    return Result<InferrableEngineConfig>::Error(\n        \"RNN state does not support max_single_sequence_length, while it is set to \" +\n        std::to_string(init_config.max_single_sequence_length.value()) +\n        \" in the input EngineConfig\");\n  }\n  // - Get the upper bound of single sequence length, prefill size and batch size\n  // from model config.\n  Result<ModelConfigLimits> model_config_limits_res =\n      GetModelConfigLimits(model_configs, model_metadata);\n  if (model_config_limits_res.IsErr()) {\n    return Result<InferrableEngineConfig>::Error(model_config_limits_res.UnwrapErr());\n  }\n  ModelConfigLimits model_config_limits = model_config_limits_res.Unwrap();\n\n  std::ostringstream os;\n  InferrableEngineConfig inferred_config = init_config;\n  // - 1. prefill_chunk_size\n  if (!init_config.prefill_chunk_size.has_value()) {\n    inferred_config.prefill_chunk_size = std::min(\n        model_config_limits.model_runtime_max_prefill_chunk_size, static_cast<int64_t>(4096));\n    os << \"prefill chunk size will be set to \" << inferred_config.prefill_chunk_size.value()\n       << \", \";\n  } else {\n    os << \"prefill chunk size \" << inferred_config.prefill_chunk_size.value()\n       << \" is specified by user, \";\n  }\n  // - 2. 
max_batch_size\n  if (!init_config.max_num_sequence.has_value()) {\n    inferred_config.max_num_sequence =\n        mode == EngineMode::kInteractive\n            ? 1\n            : std::min(static_cast<int64_t>(4), model_config_limits.model_max_batch_size);\n    os << \"max batch size will be set to \" << inferred_config.max_num_sequence.value() << \", \";\n  } else {\n    os << \"max batch size \" << inferred_config.max_num_sequence.value()\n       << \" is specified by user, \";\n  }\n  int64_t max_num_sequence = inferred_config.max_num_sequence.value();\n  // - 3. max_total_sequence_length\n  if (!init_config.max_total_sequence_length.has_value()) {\n    inferred_config.max_total_sequence_length = 32768;\n    os << \"max RNN state token capacity will be set to \"\n       << inferred_config.max_total_sequence_length.value() << \". \";\n  } else {\n    os << \"max RNN state token capacity \" << inferred_config.max_total_sequence_length.value()\n       << \" is specified by user. \";\n  }\n\n  // - Extra logging message\n  if (mode == EngineMode::kLocal) {\n    os << \"We choose small max batch size and RNN state capacity to use less GPU memory.\";\n  } else if (mode == EngineMode::kInteractive) {\n    os << \"We fix max batch size to 1 for interactive single sequence use.\";\n  } else {\n    os << \"We use as much GPU memory as possible (within the limit of gpu_memory_utilization).\";\n  }\n  if (verbose) {\n    LOG(INFO) << \"Under mode \\\"\" << EngineModeToString(mode) << \"\\\", \" << os.str();\n  }\n\n  // - Get total model parameter size and temporary in-function buffer\n  // size in bytes on single GPU.\n  int64_t params_bytes = 0;\n  int64_t temp_buffer_bytes = 0;\n  for (const ModelMetadata& metadata : model_metadata) {\n    for (const ModelMetadata::Param& param : metadata.params) {\n      int64_t param_size = param.dtype.bytes();\n      for (int64_t v : param.shape) {\n        TVM_FFI_ICHECK_GE(v, 0);\n        param_size *= v;\n      }\n      
params_bytes += param_size;\n    }\n    for (const auto& [func_name, temp_buffer_size] : metadata.memory_usage) {\n      temp_buffer_bytes += temp_buffer_size;\n    }\n  }\n  // - 4. max_history_size\n  double rnn_state_base_bytes = 0;  // The memory usage for RNN state when history = 1.\n  double model_workspace_bytes = 0;\n  double logit_processor_workspace_bytes = 0;\n  TVM_FFI_ICHECK_EQ(model_configs.size(), model_metadata.size());\n  int num_models = model_configs.size();\n  for (int i = 0; i < num_models; ++i) {\n    // - Read the vocab size and compile-time prefill chunk size (which affects memory allocation).\n    tvm::ffi::json::Object compile_time_model_config =\n        json::Lookup<tvm::ffi::json::Object>(model_configs[i], \"model_config\");\n    int64_t vocab_size = json::Lookup<int64_t>(compile_time_model_config, \"vocab_size\");\n    int64_t prefill_chunk_size =\n        json::Lookup<int64_t>(compile_time_model_config, \"prefill_chunk_size\");\n    int64_t head_size = json::Lookup<int64_t>(compile_time_model_config, \"head_size\");\n    int64_t num_heads = json::Lookup<int64_t>(compile_time_model_config, \"num_heads\");\n    int64_t num_layers = json::Lookup<int64_t>(compile_time_model_config, \"num_hidden_layers\");\n    int64_t hidden_size = json::Lookup<int64_t>(compile_time_model_config, \"hidden_size\");\n    // - Calculate RNN state memory usage.\n    rnn_state_base_bytes += (max_num_sequence * hidden_size * num_layers * 2 * 2 +\n                             max_num_sequence * num_heads * head_size * head_size * num_layers * 2);\n    model_workspace_bytes += prefill_chunk_size * 4 + max_num_sequence * 4 +\n                             (prefill_chunk_size * 2 + max_num_sequence) * hidden_size * 2;\n    logit_processor_workspace_bytes +=\n        max_num_sequence * 20 + max_num_sequence * vocab_size * 16.125;\n  }\n  int64_t gpu_size_bytes = TotalDetectGlobalMemory(device);\n  // Compute the maximum history size under the GPU memory 
budget.\n  int64_t model_max_history_size = static_cast<int>((gpu_size_bytes * gpu_memory_utilization  //\n                                                     - params_bytes                           //\n                                                     - temp_buffer_bytes                      //\n                                                     - model_workspace_bytes                  //\n                                                     - logit_processor_workspace_bytes) /\n                                                    rnn_state_base_bytes);\n  if (model_max_history_size <= 0) {\n    return Result<InferrableEngineConfig>::Error(\n        \"Insufficient GPU memory error: \"\n        \"The available single GPU memory is \" +\n        BytesToMegabytesString(gpu_size_bytes * gpu_memory_utilization) +\n        \" MB, \"\n        \"which is less than the sum of model weight size (\" +\n        BytesToMegabytesString(params_bytes) + \" MB) and temporary buffer size (\" +\n        BytesToMegabytesString(\n            (temp_buffer_bytes + model_workspace_bytes + logit_processor_workspace_bytes)) +\n        \" MB). \"\n        \"If the model weight size is too large, please use quantization. \"\n        \"If the temporary buffer size is too large, please use a smaller `--prefill-chunk-size` in \"\n        \"`mlc_llm gen_config`.\");\n  }\n  if (!init_config.max_history_size.has_value()) {\n    inferred_config.max_history_size = model_max_history_size;\n  } else {\n    inferred_config.max_history_size =\n        std::min(inferred_config.max_history_size.value(), model_max_history_size);\n  }\n  if (verbose) {\n    LOG(INFO) << \"The actual engine mode is \\\"\" << EngineModeToString(mode)\n              << \"\\\". 
So max batch size is \" << inferred_config.max_num_sequence.value()\n              << \", max RNN state token capacity is \"\n              << inferred_config.max_total_sequence_length.value() << \", prefill chunk size is \"\n              << inferred_config.prefill_chunk_size.value() << \".\";\n    LOG(INFO) << \"Estimated total single GPU memory usage: \"\n              << BytesToMegabytesString(params_bytes + temp_buffer_bytes + model_workspace_bytes +\n                                        logit_processor_workspace_bytes +\n                                        inferred_config.max_history_size.value() *\n                                            rnn_state_base_bytes)\n              << \" MB (Parameters: \" << BytesToMegabytesString(params_bytes) << \" MB. RNN state: \"\n              << BytesToMegabytesString(inferred_config.max_history_size.value() *\n                                        rnn_state_base_bytes)\n              << \" MB. Temporary buffer: \"\n              << BytesToMegabytesString(model_workspace_bytes + logit_processor_workspace_bytes +\n                                        temp_buffer_bytes)\n              << \" MB). The actual usage might be slightly larger than the estimated number.\";\n  }\n\n  return Result<InferrableEngineConfig>::Ok(inferred_config);\n}\n\n/****************** Config utils ******************/\n\nResult<bool> ModelsUseKVCache(const std::vector<tvm::ffi::json::Object>& model_configs) {\n  TVM_FFI_ICHECK_GE(model_configs.size(), 1);\n  std::string model_type = json::Lookup<std::string>(model_configs[0], \"model_type\");\n  bool use_kv_cache = model_type.find(\"rwkv\") == std::string::npos;\n  for (int i = 1; i < static_cast<int>(model_configs.size()); ++i) {\n    if ((json::Lookup<std::string>(model_configs[i], \"model_type\").find(\"rwkv\") ==\n         std::string::npos) != use_kv_cache) {\n      return Result<bool>::Error(\n          \"Invalid models in EngineConfig. 
Models must be either all RNN models or all \"\n          \"non-RNN models.\");\n    }\n  }\n  return Result<bool>::Ok(use_kv_cache);\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/config.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/config.h\n */\n#ifndef MLC_LLM_SERVE_CONFIG_H_\n#define MLC_LLM_SERVE_CONFIG_H_\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/device_api.h>\n#include <tvm/runtime/int_tuple.h>\n#include <tvm/runtime/object.h>\n\n#include <optional>\n\n#include \"../metadata/model.h\"\n#include \"../support/result.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm;\nusing namespace tvm::runtime;\nusing tvm::ffi::Array;\nusing tvm::ffi::Optional;\nusing tvm::ffi::String;\n\n/****************** GenerationConfig ******************/\n\n/*! \\brief The response format of a request. */\nstruct ResponseFormat {\n  String type = \"text\";\n  Optional<String> schema = std::nullopt;\n  /*!\n   * \\brief Create response format from JSON.\n   * \\param config_json The JSON object for the response format.\n   * \\returns The converted result.\n   */\n  static Result<ResponseFormat> FromJSON(const tvm::ffi::json::Object& config_json);\n\n  /**\n   * \\return The serialized JSON value of the response format.\n   */\n  tvm::ffi::json::Object AsJSON() const;\n};\n\nenum class SpecialRequestKind : int {\n  kNone = 0,\n  kQueryEngineMetrics = 1,\n};\n\nenum class DisaggRequestKind : int {\n  kNone = 0,\n  kPrepareReceive = 1,\n  kRemoteSend = 2,\n  kStartGeneration = 3,\n};\n\n/*! \\brief Controls the behavior of inference with grammar constraint. */\nenum class GrammarExecutionMode : int {\n  /*! \\brief If grammar is provided for a request, use the grammar to constrain the output token. */\n  kConstraint = 0,\n  /*! \\brief If grammar is provided for a request, not only constrain the output, but also use the\n   * jump-forward decoding to predict the next tokens. This is the default option. */\n  kJumpForward = 1,\n};\n\n/*! \\brief The config for disaggregation requests. 
*/\nclass DisaggConfig {\n public:\n  DisaggRequestKind kind = DisaggRequestKind::kNone;\n  std::vector<IntTuple> kv_append_metadata;\n  // \"kv_window_begin\" and \"kv_window_end\" denote the KV interval of interest.\n  // \"kv_window_end\" supports Python-style negative indexing.\n  // The concrete meaning varies for different disaggregation request kinds:\n  // - For \"prepare_receive\", the begin is always 0, and \"[0:end]\" denotes\n  // the KV range to prefill on a prefill instance.\n  // - For \"remote_send\", \"[begin:end]\" means the KV range to compute prefill\n  // and send to the decode instance.\n  // - For \"start_generation\", the end is always nullopt, and \"[begin:]\" denotes\n  // the KV range to prefill locally on the decode instance.\n  std::optional<int> kv_window_begin = std::nullopt;\n  std::optional<int> kv_window_end = std::nullopt;\n  std::optional<int> dst_group_offset = std::nullopt;\n\n  static Result<DisaggConfig> FromJSON(const tvm::ffi::json::Object& config_json);\n  tvm::ffi::json::Object AsJSON() const;\n};\n\n/*! \\brief The debug configuration of a request. */\nclass DebugConfig {\n public:\n  bool ignore_eos = false;\n  bool pinned_system_prompt = false;\n  SpecialRequestKind special_request = SpecialRequestKind::kNone;\n  /*! \\brief The grammar execution mode. */\n  GrammarExecutionMode grammar_execution_mode = GrammarExecutionMode::kJumpForward;\n  DisaggConfig disagg_config;\n\n  /*!\n   * \\brief Create debug config from JSON.\n   * \\param config_json The JSON object for the debug config.\n   * \\returns The converted result.\n   */\n  static Result<DebugConfig> FromJSON(const tvm::ffi::json::Object& config_json);\n\n  /**\n   * \\return The serialized JSON value of the debug config.\n   */\n  tvm::ffi::json::Object AsJSON() const;\n};\n\n/*! \\brief The generation configuration of a request. 
*/\nclass GenerationConfigNode : public Object {\n public:\n  int n = 1;\n  double temperature = 1.0;\n  double top_p = 1.0;\n  double frequency_penalty = 0.0;\n  double presence_penalty = 0.0;\n  double repetition_penalty = 1.0;\n  bool logprobs = false;\n  int top_logprobs = 0;\n  std::vector<std::pair<int, float>> logit_bias;\n  int seed;\n  // -1 means infinite\n  int max_tokens = -1;\n  Array<String> stop_strs;\n  std::vector<int> stop_token_ids;\n\n  ResponseFormat response_format;\n  DebugConfig debug_config;\n\n  tvm::ffi::json::Object AsJSON() const;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<GenerationConfigNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.GenerationConfig\", GenerationConfigNode, Object);\n};\n\nclass GenerationConfig : public ObjectRef {\n public:\n  /*!\n   * \\brief Run validation of generation config and ensure values are in bounds.\n   * \\return The validated generation config or error.\n   */\n  static Result<GenerationConfig> Validate(GenerationConfig cfg);\n\n  /*!\n   * \\brief Create generation config from JSON.\n   * \\param config_json The JSON object for the generation config.\n   * \\param default_config The default config.\n   */\n  static Result<GenerationConfig> FromJSON(const tvm::ffi::json::Object& config_json,\n                                           const GenerationConfig& default_config);\n\n  /*! \\brief Get the default generation config from the model config. 
*/\n  static GenerationConfig GetDefaultFromModelConfig(const tvm::ffi::json::Object& json);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GenerationConfig, ObjectRef, GenerationConfigNode);\n};\n\n/****************** Engine config ******************/\n\n/*!\n * \\brief The engine mode in MLC LLM.\n * We provide three preset modes: \"local\", \"interactive\" and \"server\".\n * The default mode is \"local\".\n * The choice of mode decides the values of \"max_batch_size\", \"max_total_sequence_length\"\n * and \"prefill_chunk_size\" when they are not explicitly specified.\n * 1. Mode \"local\" refers to the local server deployment, which has low\n * request concurrency. So the max batch size will be set to 4, and max\n * total sequence length and prefill chunk size are set to the context\n * window size (or sliding window size) of the model.\n * 2. Mode \"interactive\" refers to the interactive use of the server, which\n * has at most 1 concurrent request. So the max batch size will be set to 1,\n * and max total sequence length and prefill chunk size are set to the context\n * window size (or sliding window size) of the model.\n * 3. Mode \"server\" refers to the large server use case, which may handle\n * many concurrent requests and wants to use as much GPU memory as possible.\n * In this mode, we will automatically infer the largest possible max batch\n * size and max total sequence length.\n */\nenum class EngineMode : int {\n  kLocal = 0,\n  kInteractive = 1,\n  kServer = 2,\n};\n\n/*! \\brief The prefix cache mode. */\nenum class PrefixCacheMode : int {\n  /*! \\brief Disable prefix cache. */\n  kDisable = 0,\n  /*! \\brief The paged radix-tree-based prefix cache mode. */\n  kRadix = 1,\n};\n\n/*! \\brief The speculative mode. */\nenum class SpeculativeMode : int {\n  /*! \\brief Disable speculative decoding. */\n  kDisable = 0,\n  /*! \\brief The normal speculative decoding (small draft) mode. */\n  kSmallDraft = 1,\n  /*! 
\\brief The Eagle-style speculative decoding. */\n  kEagle = 2,\n  /*! \\brief The Medusa-style speculative decoding. */\n  kMedusa = 3,\n};\n\n/*! \\brief The prefill mode. */\nenum class PrefillMode : int {\n  /*! \\brief Only chunked prefill is enabled. */\n  kChunked = 0,\n  /*!\n   * \\brief Hybrid (split-fuse) prefill is enabled: some decode steps are fused\n   * into the prefill.\n   */\n  kHybrid = 1,\n};\n\nclass InferrableEngineConfig;\n\n/*! \\brief The engine execution configuration. */\nclass EngineConfigNode : public Object {\n public:\n  /*************** Models ***************/\n\n  /*! \\brief The path to the model directory. */\n  String model;\n  /*! \\brief The path or identifier of the model library. */\n  String model_lib;\n  /*! \\brief The paths to the additional models' directories. */\n  Array<String> additional_models;\n  /*! \\brief The paths to the additional models' libraries. */\n  Array<String> additional_model_libs;\n\n  /*************** KV cache config and engine capacities ***************/\n\n  /*!\n   * \\brief The engine mode in MLC LLM.\n   * \\sa EngineMode\n   */\n  EngineMode mode = EngineMode::kLocal;\n  /*!\n   * \\brief A number in (0, 1) denoting the fraction of GPU memory used by the server in total.\n   * It is used to infer the maximum possible KV cache capacity.\n   * When it is unspecified, it defaults to 0.85.\n   * Under mode \"local\" or \"interactive\", the actual memory usage may be\n   * significantly smaller than this number. Under mode \"server\", the actual\n   * memory usage may be slightly larger than this number.\n   */\n  float gpu_memory_utilization = 0.85;\n  /*! \\brief The number of consecutive tokens handled in each page of the paged KV cache. */\n  int kv_cache_page_size = 16;\n  /*!\n   * \\brief The maximum number of sequences that are allowed to be\n   * processed by the KV cache at any time.\n   */\n  int max_num_sequence = 4;\n  /*! 
\\brief The maximum total number of tokens whose KV data are allowed\n   * to exist in the KV cache at any time.\n   */\n  int64_t max_total_sequence_length = 4096;\n  /*! \\brief The maximum length allowed for a single sequence in the engine. */\n  int64_t max_single_sequence_length = 4096;\n  /*! \\brief The maximum total sequence length in a prefill. */\n  int64_t prefill_chunk_size = 1024;\n  /*! \\brief The maximum history size for RNN state. KV cache does not need this. */\n  int max_history_size = 0;\n\n  /*************** Prefix cache ***************/\n\n  /*! \\brief The prefix cache mode. */\n  PrefixCacheMode prefix_cache_mode = PrefixCacheMode::kRadix;\n  /*!\n   * \\brief The maximum number of recycling sequences in the prefix cache, defaulting to\n   * max_num_sequence. Set it to 0 to disable the prefix cache, or to -1 for an\n   * infinite-capacity prefix cache.\n   */\n  int prefix_cache_max_num_recycling_seqs = -1;\n\n  /*************** Speculative decoding ***************/\n\n  /*! \\brief The speculative mode. */\n  SpeculativeMode speculative_mode = SpeculativeMode::kDisable;\n  /*!\n   * \\brief The number of tokens to generate in speculative proposal (draft).\n   * A value of 0 enables adaptive speculative mode, where the draft length\n   * is automatically adjusted based on engine state.\n   */\n  int spec_draft_length = 0;\n  /*! \\brief The number of tokens to generate in speculative tree decoding. */\n  int spec_tree_width = 1;\n\n  /*************** Prefill mode ***************/\n\n  /*! \\brief The prefill mode. 
*/\n  PrefillMode prefill_mode = PrefillMode::kHybrid;\n\n  /*************** Debug ***************/\n  bool verbose = false;\n\n  String AsJSONString() const;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<EngineConfigNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.EngineConfig\", EngineConfigNode, Object);\n};\n\nclass EngineConfig : public ObjectRef {\n public:\n  /*! \\brief Create EngineConfig from JSON object and inferred config. */\n  static EngineConfig FromJSONAndInferredConfig(const tvm::ffi::json::Object& json,\n                                                const InferrableEngineConfig& inferred_config);\n\n  /*!\n   * \\brief Get all the models and model libs from the JSON string for engine initialization.\n   * \\return The parsed models/model libs from config or error message.\n   */\n  static Result<std::vector<std::pair<std::string, std::string>>>\n  GetModelsAndModelLibsFromJSONString(const std::string& json_str);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EngineConfig, ObjectRef, EngineConfigNode);\n};\n\n/*! \\brief A subset of engine config that is inferrable. */\nstruct InferrableEngineConfig {\n  std::optional<int64_t> max_num_sequence;\n  std::optional<int64_t> max_total_sequence_length;\n  std::optional<int64_t> max_single_sequence_length;\n  std::optional<int64_t> prefill_chunk_size;\n  std::optional<int64_t> max_history_size;\n\n  /*! \\brief Infer the config for KV cache from a given initial config. 
*/\n  static Result<InferrableEngineConfig> InferForKVCache(\n      EngineMode mode, Device device, double gpu_memory_utilization,\n      const std::vector<tvm::ffi::json::Object>& model_configs,\n      const std::vector<ModelMetadata>& model_metadata, InferrableEngineConfig init_config,\n      bool verbose);\n  /*! \\brief Infer the config for RNN state from a given initial config. */\n  static Result<InferrableEngineConfig> InferForRNNState(\n      EngineMode mode, Device device, double gpu_memory_utilization,\n      const std::vector<tvm::ffi::json::Object>& model_configs,\n      const std::vector<ModelMetadata>& model_metadata, InferrableEngineConfig init_config,\n      bool verbose);\n};\n\n/****************** Config utils ******************/\n\n/*! \\brief Check if the models use KV cache or RNN state. */\nResult<bool> ModelsUseKVCache(const std::vector<tvm::ffi::json::Object>& model_configs);\n\ninline std::string EngineModeToString(EngineMode mode) {\n  if (mode == EngineMode::kLocal) {\n    return \"local\";\n  } else if (mode == EngineMode::kInteractive) {\n    return \"interactive\";\n  } else if (mode == EngineMode::kServer) {\n    return \"server\";\n  } else {\n    LOG(FATAL) << \"Invalid engine mode: \" << static_cast<int>(mode);\n    throw;\n  }\n}\n\ninline EngineMode EngineModeFromString(const std::string& mode) {\n  if (mode == \"local\") {\n    return EngineMode::kLocal;\n  } else if (mode == \"interactive\") {\n    return EngineMode::kInteractive;\n  } else if (mode == \"server\") {\n    return EngineMode::kServer;\n  } else {\n    LOG(FATAL) << \"Invalid engine mode string: \" << mode;\n    throw;\n  }\n}\n\ninline std::string PrefixCacheModeToString(PrefixCacheMode prefix_cache_mode) {\n  if (prefix_cache_mode == PrefixCacheMode::kDisable) {\n    return \"disable\";\n  } else if (prefix_cache_mode == PrefixCacheMode::kRadix) {\n    return \"radix\";\n  } else {\n    LOG(FATAL) << \"Invalid prefix cache mode: \" << 
static_cast<int>(prefix_cache_mode);\n    throw;\n  }\n}\n\ninline PrefixCacheMode PrefixCacheModeFromString(const std::string& prefix_cache_mode) {\n  if (prefix_cache_mode == \"disable\") {\n    return PrefixCacheMode::kDisable;\n  } else if (prefix_cache_mode == \"radix\") {\n    return PrefixCacheMode::kRadix;\n  } else {\n    LOG(FATAL) << \"Invalid prefix cache mode string: \" << prefix_cache_mode;\n    throw;\n  }\n}\n\ninline std::string SpeculativeModeToString(SpeculativeMode speculative_mode) {\n  if (speculative_mode == SpeculativeMode::kDisable) {\n    return \"disable\";\n  } else if (speculative_mode == SpeculativeMode::kSmallDraft) {\n    return \"small_draft\";\n  } else if (speculative_mode == SpeculativeMode::kEagle) {\n    return \"eagle\";\n  } else if (speculative_mode == SpeculativeMode::kMedusa) {\n    return \"medusa\";\n  } else {\n    LOG(FATAL) << \"Invalid speculative mode: \" << static_cast<int>(speculative_mode);\n    throw;\n  }\n}\n\ninline SpeculativeMode SpeculativeModeFromString(const std::string& speculative_mode) {\n  if (speculative_mode == \"disable\") {\n    return SpeculativeMode::kDisable;\n  } else if (speculative_mode == \"small_draft\") {\n    return SpeculativeMode::kSmallDraft;\n  } else if (speculative_mode == \"eagle\") {\n    return SpeculativeMode::kEagle;\n  } else if (speculative_mode == \"medusa\") {\n    return SpeculativeMode::kMedusa;\n  } else {\n    LOG(FATAL) << \"Invalid speculative mode string: \" << speculative_mode;\n    throw;\n  }\n}\n\ninline std::string PrefillModeToString(PrefillMode prefill_mode) {\n  if (prefill_mode == PrefillMode::kChunked) {\n    return \"chunked\";\n  } else if (prefill_mode == PrefillMode::kHybrid) {\n    return \"hybrid\";\n  } else {\n    LOG(FATAL) << \"Invalid prefill mode: \" << static_cast<int>(prefill_mode);\n    throw;\n  }\n}\n\ninline PrefillMode PrefillModeFromString(const std::string& prefill_mode) {\n  if (prefill_mode == \"chunked\") {\n    return PrefillMode::kChunked;\n  } else if 
(prefill_mode == \"hybrid\") {\n    return PrefillMode::kHybrid;\n  } else {\n    LOG(FATAL) << \"Invalid prefill mode string: \" << prefill_mode;\n    throw;\n  }\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_CONFIG_H_\n"
  },
  {
    "path": "cpp/serve/data.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/data.cc\n */\n#include \"data.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n\n#include \"model.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  DataNode::RegisterReflection();\n  TextDataNode::RegisterReflection();\n  TokenDataNode::RegisterReflection();\n  ImageDataNode::RegisterReflection();\n  RequestStreamOutputObj::RegisterReflection();\n}\n\n/****************** Data ******************/\n\nstd::pair<Array<Data>, Array<Data>> SplitData(const Array<Data>& original_data, int total_length,\n                                              int split_pos) {\n  TVM_FFI_ICHECK_GE(split_pos, 0);\n  TVM_FFI_ICHECK_GE(total_length, split_pos)\n      << \"Cannot truncate when the current length is already less than the target length\";\n  std::vector<Data> lhs(original_data.begin(), original_data.end());\n  std::vector<Data> rhs;\n  while (total_length > split_pos) {\n    TVM_FFI_ICHECK(!lhs.empty());\n    Data last_data = lhs.back();\n    int last_data_length = last_data->GetLength();\n    TVM_FFI_ICHECK_GE(total_length - last_data_length, 0);\n    if (total_length - last_data_length >= split_pos) {\n      // Pop the entire last data.\n      rhs.push_back(lhs.back());\n      lhs.pop_back();\n      total_length -= last_data_length;\n      continue;\n    }\n    // Partially truncate the last data.\n    const auto* token_data = last_data.as<TokenDataNode>();\n    TVM_FFI_ICHECK(token_data != nullptr) << \"Only TokenData supports partial truncation.\";\n    int length_to_truncate = total_length - split_pos;\n    TVM_FFI_ICHECK_GT(length_to_truncate, 0);\n    TVM_FFI_ICHECK_LT(length_to_truncate, last_data_length);\n    TokenData lhs_token_data(\n        IntTuple{token_data->token_ids.begin(), token_data->token_ids.end() - length_to_truncate});\n    TokenData rhs_token_data(\n        IntTuple{token_data->token_ids.end() 
- length_to_truncate, token_data->token_ids.end()});\n    TVM_FFI_ICHECK_EQ(total_length - last_data_length + lhs_token_data->GetLength(), split_pos);\n    lhs.pop_back();\n    lhs.push_back(lhs_token_data);\n    rhs.push_back(rhs_token_data);\n    std::reverse(rhs.begin(), rhs.end());\n    total_length = split_pos;\n  }\n  return {lhs, rhs};\n}\n\n/****************** TextData ******************/\n\nTextData::TextData(String text) {\n  ObjectPtr<TextDataNode> n = tvm::ffi::make_object<TextDataNode>();\n  n->text = std::move(text);\n  data_ = std::move(n);\n}\n\nint TextDataNode::GetLength() const {\n  LOG(FATAL) << \"\\\"GetLength\\\" for TextData is not supported. \"\n                \"Please tokenize the text and construct a TokenData object.\";\n}\n\nObjectRef TextDataNode::GetEmbedding(Model model, ObjectRef* dst, int offset) const {\n  LOG(FATAL) << \"\\\"GetEmbedding\\\" for TextData is not supported. \"\n                \"Please tokenize the text and construct a TokenData object.\";\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.serve.TextData\", [](String text) { return TextData(std::move(text)); })\n      .def(\"mlc.serve.TextDataGetTextString\", [](TextData data) { return data->text; });\n}\n\n/****************** TokenData ******************/\n\nTokenData::TokenData(IntTuple token_ids) {\n  ObjectPtr<TokenDataNode> n = tvm::ffi::make_object<TokenDataNode>();\n  n->token_ids = std::move(token_ids);\n  data_ = std::move(n);\n}\n\nTokenData::TokenData(std::vector<int32_t> token_ids) {\n  ObjectPtr<TokenDataNode> n = tvm::ffi::make_object<TokenDataNode>();\n  n->token_ids = IntTuple(token_ids.begin(), token_ids.end());\n  data_ = std::move(n);\n}\n\nint TokenDataNode::GetLength() const { return token_ids.size(); }\n\nObjectRef TokenDataNode::GetEmbedding(Model model, ObjectRef* dst, int offset) const {\n  return model->TokenEmbed(token_ids, dst, offset);\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() 
{\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def_packed(\"mlc.serve.TokenData\",\n                  [](ffi::PackedArgs args, ffi::Any* rv) {\n                    std::vector<int32_t> token_ids;\n                    token_ids.reserve(args.size());\n                    for (int i = 0; i < args.size(); i++) {\n                      token_ids.push_back(args[i].cast<int32_t>());\n                    }\n                    *rv = TokenData(std::move(token_ids));\n                  })\n      .def(\"mlc.serve.TokenDataGetTokenIds\", [](TokenData data) { return data->token_ids; });\n}\n\n/****************** ImageData ******************/\n\nImageData::ImageData(Tensor image, int embed_size) {\n  ObjectPtr<ImageDataNode> n = tvm::ffi::make_object<ImageDataNode>();\n  n->image = std::move(image);\n  n->embed_size = embed_size;\n  data_ = std::move(n);\n}\n\nint ImageDataNode::GetLength() const { return embed_size; }\n\nObjectRef ImageDataNode::GetEmbedding(Model model, ObjectRef* dst, int offset) const {\n  return model->ImageEmbed(image, dst, offset);\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.serve.ImageData\",\n           [](Tensor image, int embed_size) { return ImageData(std::move(image), embed_size); })\n      .def(\"mlc.serve.ImageDataGetImage\", [](ImageData data) { return data->image; });\n}\n\n/****************** SampleResult ******************/\n\n/*! \\brief Convert a single token with probability to JSON string. 
*/\ninline void TokenToLogProbJSON(const Tokenizer& tokenizer, const TokenProbPair& token_prob,\n                               std::ostringstream* os) {\n  const std::string& token = tokenizer->PostProcessedTokenTable()[token_prob.first];\n\n  (*os) << \"\\\"token\\\": \\\"\";\n  for (char ch : token) {\n    if (ch >= 33 && ch <= 126) {\n      // The character is in ASCII visible range.\n      // Handle escape characters in JSON.\n      if (ch == '\"') {\n        (*os) << \"\\\\\\\"\";\n      } else if (ch == '\\\\') {\n        (*os) << \"\\\\\\\\\";\n      } else {\n        (*os) << ch;\n      }\n    }\n  }\n  (*os) << \"\\\", \";\n  (*os) << \"\\\"logprob\\\": \" << std::log(std::max(token_prob.second, 1e-10f)) << \", \";\n  (*os) << \"\\\"bytes\\\": [\";\n  int token_len = token.size();\n  for (int pos = 0; pos < token_len; ++pos) {\n    (*os) << static_cast<int>(static_cast<unsigned char>(token[pos]));\n    if (pos != token_len - 1) {\n      (*os) << \", \";\n    }\n  }\n  (*os) << \"]\";\n}\n\nint32_t SampleResult::GetTokenId() const { return this->sampled_token_id.first; }\n\nstd::string SampleResult::GetLogProbJSON(const Tokenizer& tokenizer, bool logprob) const {\n  TVM_FFI_ICHECK(top_prob_tokens.empty() || logprob);\n  if (!logprob) {\n    // Logprob is not needed.\n    return \"\";\n  }\n\n  std::ostringstream os;\n  os << \"{\";\n  // - Convert the sampled token to JSON.\n  TokenToLogProbJSON(tokenizer, sampled_token_id, &os);\n  // - Convert the tokens with top probabilities.\n  os << \", \\\"top_logprobs\\\": [\";\n  int num_top = top_prob_tokens.size();\n  for (int i = 0; i < num_top; ++i) {\n    os << \"{\";\n    TokenToLogProbJSON(tokenizer, top_prob_tokens[i], &os);\n    os << \"}\";\n    if (i != num_top - 1) {\n      os << \", \";\n    }\n  }\n  os << \"]}\";\n  return os.str();\n}\n\n/****************** RequestStreamOutput ******************/\n\nRequestStreamOutput::RequestStreamOutput(\n    String request_id, std::vector<std::vector<int64_t>> 
group_delta_token_ids,\n    std::optional<std::vector<std::vector<String>>> group_delta_logprob_json_strs,\n    std::vector<Optional<String>> group_finish_reason,\n    std::vector<String> group_extra_prefix_string) {\n  ObjectPtr<RequestStreamOutputObj> n = tvm::ffi::make_object<RequestStreamOutputObj>();\n  n->request_id = std::move(request_id);\n  n->group_delta_token_ids = std::move(group_delta_token_ids);\n  n->group_delta_logprob_json_strs = std::move(group_delta_logprob_json_strs);\n  n->group_finish_reason = std::move(group_finish_reason);\n  n->group_extra_prefix_string = std::move(group_extra_prefix_string);\n  data_ = std::move(n);\n}\n\nRequestStreamOutput RequestStreamOutput::Usage(String request_id,\n                                               String request_final_usage_json_str) {\n  ObjectPtr<RequestStreamOutputObj> n = tvm::ffi::make_object<RequestStreamOutputObj>();\n  n->request_id = std::move(request_id);\n  n->request_final_usage_json_str = std::move(request_final_usage_json_str);\n  return RequestStreamOutput(n);\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef().def(\"mlc.serve.RequestStreamOutputUnpack\", [](RequestStreamOutput output) {\n    TVM_FFI_ICHECK(!output->unpacked)\n        << \"One RequestStreamOutput can be unpacked for at most once.\";\n    std::vector<IntTuple> group_delta_token_ids;\n    std::vector<Array<String>> group_delta_logprob_json_strs;\n    group_delta_token_ids.reserve(output->group_delta_token_ids.size());\n    if (output->group_delta_logprob_json_strs.has_value()) {\n      group_delta_logprob_json_strs.reserve(output->group_delta_token_ids.size());\n    }\n    for (int i = 0; i < static_cast<int>(output->group_delta_token_ids.size()); ++i) {\n      group_delta_token_ids.push_back(output->group_delta_token_ids[i]);\n      if (output->group_delta_logprob_json_strs.has_value()) {\n        
group_delta_logprob_json_strs.push_back(output->group_delta_logprob_json_strs.value()[i]);\n      }\n    }\n    Array<Any> ret = {output->request_id,\n                      Array<IntTuple>(std::move(group_delta_token_ids)),\n                      output->group_delta_logprob_json_strs.has_value()\n                          ? Array<Array<String>>(std::move(group_delta_logprob_json_strs))\n                          : Optional<Array<Array<String>>>(),\n                      Array<Optional<String>>(output->group_finish_reason),\n                      output->request_final_usage_json_str,\n                      Array<String>(output->group_extra_prefix_string)};\n    output->unpacked = true;\n    return ret;\n  });\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/data.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/data.h\n */\n#ifndef MLC_LLM_SERVE_DATA_H_\n#define MLC_LLM_SERVE_DATA_H_\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/container/shape.h>\n#include <tvm/ffi/optional.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n#include <tvm/node/cast.h>\n#include <tvm/runtime/int_tuple.h>\n#include <tvm/runtime/object.h>\n#include <tvm/runtime/tensor.h>\n\n#include <atomic>\n#include <optional>\n\n#include \"../tokenizers/tokenizers.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\nusing tvm::ffi::Optional;\n\nclass Model;\n\n/****************** DataNode ******************/\n\n/*! \\brief The base class of multi-modality data (text, tokens, embedding, etc.). */\nclass DataNode : public Object {\n public:\n  /*! \\brief Get the length (equivalent number of tokens) of the data. */\n  virtual int GetLength() const = 0;\n\n  /*!\n   * \\brief Compute the embedding of this data with regard to the input model.\n   * When the destination pointer is not nullptr, the embedding is written in place\n   * into the destination array at the given offset.\n   * Otherwise, the computed embeddings are returned directly.\n   * \\param model The model to take embeddings from.\n   * \\param dst The destination array of the embedding lookup.\n   * \\param offset The token offset where the computed embeddings will be written\n   * into the destination array.\n   * \\return The updated destination embedding array or the computed embeddings.\n   * \\note When `dst` is nullptr, we require `offset` to be 0.\n   */\n  virtual ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<DataNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const 
bool _type_has_method_shash_reduce = false;\n  static constexpr const uint32_t _type_child_slots = 3;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.Data\", DataNode, Object);\n};\n\nclass Data : public ObjectRef {\n public:\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Data, ObjectRef, DataNode);\n};\n\n/*! \\brief Split the given data array into two arrays at the \"split_pos\" position. */\nstd::pair<Array<Data>, Array<Data>> SplitData(const Array<Data>& original_data, int total_length,\n                                              int split_pos);\n\n/****************** TextDataNode ******************/\n\n/*! \\brief The class of text data, containing a text string. */\nclass TextDataNode : public DataNode {\n public:\n  /*! \\brief The text string. */\n  tvm::ffi::String text;\n\n  int GetLength() const final;\n  ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const final;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<TextDataNode>();\n  }\n\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.TextData\", TextDataNode, DataNode);\n};\n\nclass TextData : public Data {\n public:\n  explicit TextData(String text);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TextData, Data, TextDataNode);\n};\n\n/****************** TokenDataNode ******************/\n\n/*! \\brief The class of token data, containing a list of token ids. */\nclass TokenDataNode : public DataNode {\n public:\n  /*! \\brief The token ids. 
*/\n  IntTuple token_ids;\n\n  int GetLength() const final;\n  ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const final;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<TokenDataNode>();\n  }\n\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.TokenData\", TokenDataNode, DataNode);\n};\n\nclass TokenData : public Data {\n public:\n  explicit TokenData(IntTuple token_ids);\n\n  explicit TokenData(std::vector<int32_t> token_ids);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TokenData, Data, TokenDataNode);\n};\n\n/****************** ImageDataNode ******************/\n\n/*! \\brief The class of image data, containing a 3D array of pixel values. */\nclass ImageDataNode : public DataNode {\n public:\n  /*! \\brief The pixel values. */\n  Tensor image;\n  int embed_size;\n\n  int GetLength() const final;\n  ObjectRef GetEmbedding(Model model, ObjectRef* dst = nullptr, int offset = 0) const final;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<ImageDataNode>();\n  }\n\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.ImageData\", ImageDataNode, DataNode);\n};\n\nclass ImageData : public Data {\n public:\n  explicit ImageData(Tensor image, int embed_size);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ImageData, Data, ImageDataNode);\n};\n\n/****************** SampleResult ******************/\n\n// The pair of a token id and its probability in sampling.\nusing TokenProbPair = std::pair<int32_t, float>;\n\n/*!\n * \\brief The class of sampler's sampling result.\n * It's not a TVM object since it will not be used directly on Python side.\n */\nstruct SampleResult {\n  /*! \\brief The token id and probability of the sampled token. */\n  TokenProbPair sampled_token_id;\n  /*! \\brief The token id and probability of the tokens with top probabilities. */\n  std::vector<TokenProbPair> top_prob_tokens;\n\n  /*! 
\\brief Get the sampled token id. */\n  int32_t GetTokenId() const;\n\n  /*!\n   * \\brief Get the logprob JSON string of this token with regard\n   * to the OpenAI API at https://platform.openai.com/docs/api-reference/chat/object.\n   * \\param tokenizer The tokenizer for token table lookup.\n   * \\param logprob A boolean indicating whether to return the log probability.\n   * \\return A JSON string that conforms to the logprob spec in the OpenAI API.\n   */\n  std::string GetLogProbJSON(const Tokenizer& tokenizer, bool logprob) const;\n};\n\n/****************** RequestStreamOutput ******************/\n\n/*!\n * \\brief The generated delta request output that is streamed back\n * through the callback stream function.\n *\n * \\note: This output object corresponds to parallel generated outputs when n != 1.\n *\n * For example, if n=2, then group_delta_token_ids[0] corresponds to output stream 0\n * and group_delta_token_ids[1] corresponds to output stream 1.\n */\nclass RequestStreamOutputObj : public Object {\n public:\n  /*! \\brief The id of the request that the function is invoked for. */\n  String request_id;\n  /*!\n   * \\brief The newly generated token ids since the last callback invocation\n   * for the input request.\n   */\n  std::vector<std::vector<int64_t>> group_delta_token_ids;\n  /*! \\brief The logprobs JSON strings of the newly generated tokens since the last invocation. 
*/\n  std::optional<std::vector<std::vector<String>>> group_delta_logprob_json_strs;\n  /*!\n   * \\brief The finish reason of the request when it is finished,\n   * or None if the request has not finished yet.\n   */\n  std::vector<Optional<String>> group_finish_reason;\n  /*!\n   * \\brief The usage field of the response. It is shared by all streams.\n   */\n  Optional<String> request_final_usage_json_str;\n\n  /*!\n   * \\brief The extra prefix string of all requests.\n   */\n  std::vector<String> group_extra_prefix_string;\n\n  std::atomic<bool> unpacked = false;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<RequestStreamOutputObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.RequestStreamOutput\", RequestStreamOutputObj, Object);\n};\n\n/*!\n * \\brief Managed reference to RequestStreamOutputObj.\n * \\sa RequestStreamOutputObj\n */\nclass RequestStreamOutput : public ObjectRef {\n public:\n  explicit RequestStreamOutput(\n      String request_id, std::vector<std::vector<int64_t>> group_delta_token_ids,\n      std::optional<std::vector<std::vector<String>>> group_delta_logprob_json_strs,\n      std::vector<Optional<String>> group_finish_reason,\n      std::vector<String> group_extra_prefix_string);\n\n  static RequestStreamOutput Usage(String request_id, String request_final_usage_json_str);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RequestStreamOutput, ObjectRef,\n                                             RequestStreamOutputObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_DATA_H_\n"
  },
  {
    "path": "cpp/serve/draft_token_workspace_manager.cc",
    "content": "/*!\n * Copyright (c) 2023-2025 by Contributors\n * \\file serve/draft_token_workspace_manager.cc\n */\n\n#include \"draft_token_workspace_manager.h\"\n\n#include \"model.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() { DraftTokenWorkspaceManagerObj::RegisterReflection(); }\n\nDraftTokenWorkspaceManagerObj::DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size,\n                                                             int hidden_size,\n                                                             DLDataType hidden_states_dtype,\n                                                             DLDevice device,\n                                                             const FunctionTable& ft)\n    : max_num_tokens_(max_num_tokens),\n      vocab_size_(vocab_size),\n      hidden_size_(hidden_size),\n      hidden_states_dtype_(hidden_states_dtype),\n      device_(device),\n      ft_(ft) {\n  free_slots_.resize(max_num_tokens);\n  std::iota(free_slots_.begin(), free_slots_.end(), 0);\n}\n\nvoid DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots, std::vector<int>* result) {\n  TVM_FFI_ICHECK_LE(num_slots, free_slots_.size());\n  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);\n  free_slots_.resize(free_slots_.size() - num_slots);\n  for (int slot : (*result)) {\n    ref_count_[slot] = 1;\n  }\n}\n\nvoid DraftTokenWorkspaceManagerObj::AllocSlots(int num_slots,\n                                               const std::vector<int>& initial_ref_count,\n                                               std::vector<int>* result) {\n  TVM_FFI_ICHECK_LE(num_slots, free_slots_.size());\n  TVM_FFI_ICHECK_EQ(num_slots, initial_ref_count.size());\n  result->assign(free_slots_.rbegin(), free_slots_.rbegin() + num_slots);\n  free_slots_.resize(free_slots_.size() - num_slots);\n  for (int i = 0; i < num_slots; ++i) {\n    int slot = (*result)[i];\n    TVM_FFI_ICHECK(initial_ref_count[i] > 
0);\n    ref_count_[slot] = initial_ref_count[i];\n  }\n}\n\nvoid DraftTokenWorkspaceManagerObj::FreeSlots(const std::vector<int>& slots) {\n  for (int slot : slots) {\n    if (--ref_count_.at(slot) == 0) {\n      free_slots_.push_back(slot);\n      ref_count_.erase(slot);\n    }\n  }\n}\n\nvoid DraftTokenWorkspaceManagerObj::AllocWorkspace(ModelWorkspace* workspace,\n                                                   bool require_hidden_states) {\n  workspace->draft_probs =\n      Tensor::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);\n  workspace->draft_probs_storage =\n      Tensor::Empty({max_num_tokens_, vocab_size_}, DataType::Float(32), device_);\n  if (require_hidden_states) {\n    workspace->draft_hidden_states_storage = ft_.Empty(\n        {max_num_tokens_, hidden_size_}, hidden_states_dtype_, device_, /*worker0_only=*/false);\n  }\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/draft_token_workspace_manager.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/draft_token_workspace_manager.h\n */\n\n#ifndef MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_\n#define MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/device_api.h>\n\n#include <numeric>\n#include <optional>\n#include <unordered_map>\n#include <vector>\n\n#include \"data.h\"\n#include \"function_table.h\"\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\n\nstruct ModelWorkspace;\n\n/*!\n * \\brief Manages the workspace for draft token generation.\n *\n * The workspace is used to store the associated states for each draft token, including the\n * probability distribution of the draft token, the hidden states, etc. The workspace manager\n * maintains a pool of slots for the draft tokens to store the states.\n */\nclass DraftTokenWorkspaceManagerObj : public Object {\n public:\n  /*!\n   * \\brief Constructor\n   * \\param max_num_tokens The maximum number of draft tokens that can be stored in the workspace.\n   * \\param vocab_size The size of the vocabulary.\n   * \\param hidden_size The size of the hidden states.\n   * \\param hidden_states_dtype The data type of the hidden states.\n   * \\param device The device running the model.\n   * \\param ft The function table.\n   */\n  DraftTokenWorkspaceManagerObj(int max_num_tokens, int vocab_size, int hidden_size,\n                                DLDataType hidden_states_dtype, DLDevice device,\n                                const FunctionTable& ft);\n\n  /*!\n   * \\brief Allocate the workspace for draft tokens and update the `ModelWorkspace` data structure.\n   * \\param workspace The object that stores the allocated draft token workspace.\n   * \\param require_hidden_states Whether to allocate workspace for the hidden states.\n   */\n  void AllocWorkspace(ModelWorkspace* workspace, bool require_hidden_states);\n\n  /*!\n   * \\brief Allocate 
slots for the draft tokens.\n   * \\param num_slots The number of slots to allocate.\n   * \\param result The vector to store the allocated slots.\n   */\n  void AllocSlots(int num_slots, std::vector<int>* result);\n\n  /*!\n   * \\brief Allocate slots for the draft tokens.\n   * \\param num_slots The number of slots to allocate.\n   * \\param initial_ref_count The initial reference count for each slot.\n   * \\param result The vector to store the allocated slots.\n   */\n  void AllocSlots(int num_slots, const std::vector<int>& initial_ref_count,\n                  std::vector<int>* result);\n\n  /*!\n   * \\brief Free the slots.\n   * \\param slots The slots to free.\n   */\n  void FreeSlots(const std::vector<int>& slots);\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<DraftTokenWorkspaceManagerObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.DraftTokenWorkspaceManager\",\n                                    DraftTokenWorkspaceManagerObj, Object);\n\n private:\n  std::vector<int> free_slots_;\n  int max_num_tokens_;\n  int vocab_size_;\n  int hidden_size_;\n  DataType hidden_states_dtype_;\n  DLDevice device_;\n  const FunctionTable& ft_;\n  std::unordered_map<int, int> ref_count_;\n};\n\nclass DraftTokenWorkspaceManager : public ObjectRef {\n public:\n  DraftTokenWorkspaceManager(int max_num_tokens, int vocab_size, int hidden_size,\n                             DLDataType hidden_states_dtype, DLDevice device,\n                             const FunctionTable& ft) {\n    data_ = tvm::ffi::make_object<DraftTokenWorkspaceManagerObj>(\n        max_num_tokens, vocab_size, hidden_size, hidden_states_dtype, device, ft);\n  }\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(DraftTokenWorkspaceManager, 
ObjectRef,\n                                             DraftTokenWorkspaceManagerObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_DRAFT_TOKEN_WORKSPACE_MANAGER_H_\n"
  },
  {
    "path": "cpp/serve/engine.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine.cc\n * \\brief The implementation for runtime module of serving engine module in MLC LLM.\n */\n#include \"engine.h\"\n\n#include <dlpack/dlpack.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/logging.h>\n#include <tvm/runtime/memory/memory_manager.h>\n#include <tvm/runtime/module.h>\n#include <tvm/runtime/nvtx.h>\n#include <tvm/runtime/threading_backend.h>\n#include <xgrammar/xgrammar.h>\n\n#include <cstdlib>\n#include <functional>\n#include <numeric>\n#include <optional>\n#include <tuple>\n#include <unordered_set>\n\n#include \"../support/json_parser.h\"\n#include \"../support/result.h\"\n#include \"../support/utils.h\"\n#include \"../tokenizers/tokenizers.h\"\n#include \"engine_actions/action.h\"\n#include \"engine_actions/action_commons.h\"\n#include \"engine_state.h\"\n#include \"event_trace_recorder.h\"\n#include \"logit_processor.h\"\n#include \"model.h\"\n#include \"request.h\"\n#include \"request_state.h\"\n#include \"sampler/sampler.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\nusing tvm::ffi::Function;\n\nclass EngineModule;\n\n// get tokenizer info from model config\ninline std::optional<TokenizerInfo> GetTokenizerInfo(const tvm::ffi::json::Object& model_config) {\n  if (model_config.count(\"tokenizer_info\") == 0) {\n    LOG(WARNING) << \"Tokenizer info not found in mlc-chat-config.json. 
\"\n                 << \"Trying to automatically detect the tokenizer info\";\n    return std::nullopt;\n  }\n  const tvm::ffi::json::Object& tokenizer_info_obj =\n      model_config.at(\"tokenizer_info\").cast<tvm::ffi::json::Object>();\n  auto info = tvm::ffi::make_object<TokenizerInfoNode>();\n  if (tokenizer_info_obj.count(\"token_postproc_method\")) {\n    info->token_postproc_method =\n        tokenizer_info_obj.at(\"token_postproc_method\").cast<std::string>();\n  }\n  if (tokenizer_info_obj.count(\"prepend_space_in_encode\")) {\n    info->prepend_space_in_encode = tokenizer_info_obj.at(\"prepend_space_in_encode\").cast<bool>();\n  }\n  if (tokenizer_info_obj.count(\"strip_space_in_decode\")) {\n    info->strip_space_in_decode = tokenizer_info_obj.at(\"strip_space_in_decode\").cast<bool>();\n  }\n  return TokenizerInfo(info);\n}\n\ninline std::pair<std::optional<std::string>, int> GetEnvSocketHostPort() {\n  char* host_str = std::getenv(\"MLC_SOCKET_HOST\");\n  char* port_str = std::getenv(\"MLC_SOCKET_PORT\");\n  if (host_str == nullptr || port_str == nullptr) {\n    return {std::nullopt, -1};\n  }\n  std::string host(host_str);\n  if (host.empty()) {\n    return {std::nullopt, -1};\n  }\n  return {host, std::atoi(port_str)};\n}\n\n// Stream back a terminating output with the given finish reason (e.g. an error or abort).\nvoid StreamBackErrorImpl(Request request, FRequestStreamCallback request_stream_callback,\n                         String finish_reason) {\n  // Send an empty delta for each of the n parallel streams, marked with the finish reason,\n  // so that the request is not processed further.\n  Array<RequestStreamOutput> output{RequestStreamOutput(\n      request->id, std::vector<std::vector<int64_t>>(request->generation_cfg->n), std::nullopt,\n      std::vector<Optional<String>>(request->generation_cfg->n, finish_reason),\n      std::vector<String>(request->generation_cfg->n))};\n  // NOTE: Invariant requirement\n  // always stream back final usage\n  // otherwise the frontend may have issues deciding termination\n  String 
dummy_usage = (\"{ \\\"prompt_tokens\\\": 0, \\\"completion_tokens\\\": 0, \\\"total_tokens\\\": 0 }\");\n  output.push_back(RequestStreamOutput::Usage(request->id, dummy_usage));\n  if (request_stream_callback != nullptr) {\n    request_stream_callback(output);\n  }\n}\n\nvoid AbortRequestImpl(EngineState estate, const Array<Model>& models, const String& request_id,\n                      String finish_reason) {\n  auto it_rstate = estate->request_states.find(request_id);\n  if (it_rstate == estate->request_states.end()) {\n    // The request to abort does not exist.\n    return;\n  }\n\n  RequestState rstate = it_rstate->second;\n  Request request = rstate->entries[0]->request;\n\n  // - Check if the request is running or pending.\n  auto it_running = std::find(estate->running_queue.begin(), estate->running_queue.end(), request);\n  auto it_waiting = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request);\n\n  estate->request_states.erase(request->id);\n  if (it_running != estate->running_queue.end()) {\n    // The request to abort is in running queue\n    estate->running_queue.erase(it_running);\n\n    for (int i = static_cast<int>(rstate->entries.size()) - 1; i >= 0; --i) {\n      if (estate->prefix_cache->HasSequence(rstate->entries[i]->mstates[0]->internal_id)) {\n        estate->prefix_cache->RecycleSequence(rstate->entries[i]->mstates[0]->internal_id,\n                                              /*lazy=*/false);\n      } else {\n        if (rstate->entries[i]->status != RequestStateStatus::kAlive) {\n          estate->id_manager.RecycleId(rstate->entries[i]->mstates[0]->internal_id);\n          continue;\n        }\n        RemoveRequestFromModel(estate, rstate->entries[i]->mstates[0]->internal_id, models);\n        estate->id_manager.RecycleId(rstate->entries[i]->mstates[0]->internal_id);\n      }\n    }\n  }\n  if (it_waiting != estate->waiting_queue.end()) {\n    // The request to abort is in waiting queue\n    
estate->waiting_queue.erase(it_waiting);\n  }\n  // TODO: how should abortion be handled when the request is in neither queue?\n\n  // Invoke the stream callback to report the abortion.\n  StreamBackErrorImpl(request, estate->request_stream_callback_, finish_reason);\n  estate->running_rsentries_changed = true;\n}\n\n/*!\n *  \\brief This is a mock engine that always echoes back the inputs\n *   and attaches the generation config to usage.extra.\n *\n * \\note: Mock engine tests cannot replace real engine tests.\n *\n * They only test that parameters are converted and\n * passed correctly to the backend.\n */\nclass MockEchoEngineImpl : public Engine {\n public:\n  static Result<EngineCreationOutput> Create(const std::string& engine_config_json_str,\n                                             FRequestStreamCallback request_stream_callback,\n                                             const tvm::ffi::json::Object& model_config) {\n    using TResult = Result<EngineCreationOutput>;\n    // set dummy values\n    InferrableEngineConfig inferrable_config;\n    inferrable_config.max_num_sequence = 32;\n    inferrable_config.max_total_sequence_length = 32 * 4096;\n    inferrable_config.max_single_sequence_length = 4096;\n    inferrable_config.prefill_chunk_size = 1024;\n    inferrable_config.max_history_size = 1024;\n    tvm::ffi::String err;\n    auto config_json = tvm::ffi::json::Parse(engine_config_json_str, &err);\n    if (!err.empty()) {\n      return TResult::Error(err);\n    }\n    EngineConfig engine_config = EngineConfig::FromJSONAndInferredConfig(\n        config_json.cast<tvm::ffi::json::Object>(), inferrable_config);\n\n    auto n = std::make_unique<MockEchoEngineImpl>();\n    n->request_stream_callback_ = request_stream_callback;\n    n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_config));\n    // - Get the default generation config from the first model.\n    GenerationConfig default_generation_cfg =\n        
GenerationConfig::GetDefaultFromModelConfig(model_config);\n    return TResult::Ok({std::move(n), std::move(engine_config), std::move(default_generation_cfg)});\n  }\n\n  void Reset() final {}\n\n  bool Empty() final { return request_map_.empty(); }\n\n  void SetRequestStreamCallback(FRequestStreamCallback request_stream_callback) final {\n    request_stream_callback_ = request_stream_callback;\n  }\n\n  FRequestStreamCallback GetRequestStreamCallback() final { return request_stream_callback_; }\n\n  void AddRequest(Request request) final {\n    // precompute the stream back results and store them in request_map_\n    request = Request::FromUntokenized(request, tokenizer_);\n    std::vector<RequestStreamOutput> outputs;\n    int64_t completion_tokens = 0;\n    int64_t prompt_tokens = 0;\n\n    for (Data input : request->inputs) {\n      // only stream back token data\n      if (auto* token_data = input.as<TokenDataNode>()) {\n        for (int64_t token_id : token_data->token_ids) {\n          prompt_tokens += 1;\n          completion_tokens += 1;\n          if (request->generation_cfg->max_tokens == -1 ||\n              completion_tokens <= request->generation_cfg->max_tokens) {\n            outputs.push_back(RequestStreamOutput(\n                request->id,\n                std::vector<std::vector<int64_t>>(request->generation_cfg->n, {token_id}),\n                std::nullopt,\n                std::vector<Optional<String>>(request->generation_cfg->n, std::nullopt),\n                std::vector<String>(request->generation_cfg->n)));\n          }\n        }\n      }\n    }\n\n    // finish with \"length\" if the echoed output goes beyond max_tokens\n    String finish_reason = \"stop\";\n    if (request->generation_cfg->max_tokens != -1 &&\n        prompt_tokens > request->generation_cfg->max_tokens) {\n      finish_reason = \"length\";\n    }\n    std::vector<std::vector<int64_t>> group_delta_token_ids;\n\n    // replace the last output with one that carries the correct finish reason\n    if (outputs.size() > 0) {\n      
group_delta_token_ids = outputs.back()->group_delta_token_ids;\n      outputs.pop_back();\n    }\n    outputs.push_back(RequestStreamOutput(\n        request->id, group_delta_token_ids, std::nullopt,\n        std::vector<Optional<String>>(request->generation_cfg->n, finish_reason),\n        std::vector<String>(request->generation_cfg->n)));\n\n    // attach usage and config\n    tvm::ffi::json::Object usage;\n    usage.Set(\"prompt_tokens\", static_cast<int64_t>(prompt_tokens));\n    usage.Set(\"completion_tokens\",\n              static_cast<int64_t>(completion_tokens * request->generation_cfg->n));\n    usage.Set(\"total_tokens\",\n              static_cast<int64_t>(prompt_tokens + completion_tokens * request->generation_cfg->n));\n    usage.Set(\"extra\", request->generation_cfg->AsJSON());\n    // NOTE: Invariant requirement\n    // always stream back final usage\n    // otherwise frontend may have issues deciding termination\n    outputs.push_back(RequestStreamOutput::Usage(request->id, tvm::ffi::json::Stringify(usage)));\n    // reverse the stream back so we can just pop back and get out\n    std::reverse(outputs.begin(), outputs.end());\n\n    request_map_[request->id] = MockRequestState{request, std::move(outputs)};\n  }\n\n  void AbortRequest(const String& request_id) {\n    auto it = request_map_.find(request_id);\n    if (it == request_map_.end()) return;\n    Request request = it->second.request;\n\n    // Stream back an empty delta for each of the n parallel streams,\n    // marked with the finish reason \"abort\".\n    Array<RequestStreamOutput> output{RequestStreamOutput(\n        request_id, std::vector<std::vector<int64_t>>(request->generation_cfg->n), std::nullopt,\n        std::vector<Optional<String>>(request->generation_cfg->n, String(\"abort\")),\n        std::vector<String>(request->generation_cfg->n))};\n    // NOTE: Invariant requirement\n    // always stream back final usage\n    // otherwise frontend may have 
issues deciding termination\n    String dummy_usage =\n        (\"{ \\\"prompt_tokens\\\": 0, \\\"completion_tokens\\\": 0, \\\"total_tokens\\\": 0 }\");\n    output.push_back(RequestStreamOutput::Usage(request->id, dummy_usage));\n    request_map_.erase(it);\n    if (request_stream_callback_ != nullptr) {\n      request_stream_callback_(output);\n    }\n  }\n\n  void AbortAllRequests() final {\n    // collect the ids first to avoid erasing from request_map_ while iterating over it\n    std::vector<String> request_ids;\n    for (const auto& kv : request_map_) {\n      request_ids.push_back(kv.first);\n    }\n    for (String req_id : request_ids) {\n      AbortRequest(req_id);\n    }\n  }\n\n  void Step() final {\n    Array<RequestStreamOutput> outputs;\n    std::vector<String> finished_request_ids;\n    for (auto& kv : request_map_) {\n      MockRequestState& state = kv.second;\n      TVM_FFI_ICHECK_GE(state.reversed_outputs.size(), 2);\n      if (state.reversed_outputs.size() == 2) {\n        outputs.push_back(state.reversed_outputs.back());\n        state.reversed_outputs.pop_back();\n        outputs.push_back(state.reversed_outputs.back());\n        finished_request_ids.push_back(kv.first);\n      } else {\n        outputs.push_back(state.reversed_outputs.back());\n        state.reversed_outputs.pop_back();\n      }\n    }\n    for (String req_id : finished_request_ids) {\n      request_map_.erase(req_id);\n    }\n    if (request_stream_callback_ != nullptr) {\n      request_stream_callback_(outputs);\n    }\n  }\n\n  /************** Debug/Profile **************/\n\n  /*! \\brief Internal engine metrics. */\n  String JSONMetrics() final { return \"{}\"; }\n\n  /*! \\brief Call the given global function on all workers. Only for debugging purposes. 
*/\n  void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) final {}\n\n private:\n  struct MockRequestState {\n    Request request;\n    std::vector<RequestStreamOutput> reversed_outputs;\n  };\n\n  // internal tokenizer\n  // keep for future usage, in case we want to echo back the tokens\n  Tokenizer tokenizer_;\n  // callback stream\n  FRequestStreamCallback request_stream_callback_;\n  // active requests\n  std::unordered_map<String, MockRequestState> request_map_;\n};\n\n/********************** Engine Impl **********************/\n\n/*! \\brief The implementation of Engine. */\nclass EngineImpl : public Engine {\n  friend class EngineModule;\n\n public:\n  /********************** Engine Management **********************/\n\n  static Result<EngineCreationOutput> Create(const std::string& engine_config_json_str,\n                                             DLDevice device,\n                                             FRequestStreamCallback request_stream_callback,\n                                             Optional<EventTraceRecorder> trace_recorder) {\n    using TResult = Result<EngineCreationOutput>;\n    std::unique_ptr<EngineImpl> n = std::make_unique<EngineImpl>();\n\n    // - Read the models and model libs from the EngineConfig JSON string.\n    Result<std::vector<std::pair<std::string, std::string>>> models_and_model_libs_res =\n        EngineConfig::GetModelsAndModelLibsFromJSONString(engine_config_json_str);\n    if (models_and_model_libs_res.IsErr()) {\n      return TResult::Error(models_and_model_libs_res.UnwrapErr());\n    }\n    std::vector<std::pair<std::string, std::string>> models_and_model_libs =\n        models_and_model_libs_res.Unwrap();\n\n    int num_model = models_and_model_libs.size();\n    TVM_FFI_ICHECK_GE(num_model, 1);\n    // - Initialize singleton states inside the engine.\n    n->estate_->Reset();\n    n->estate_->request_stream_callback_ = std::move(request_stream_callback);\n    
n->trace_recorder_ = trace_recorder;\n    n->device_ = device;\n    // - Load model config, create a shared disco session when tensor\n    // parallelism is enabled.\n    std::vector<std::string> model_libs;\n    std::vector<tvm::ffi::json::Object> model_configs;\n    model_libs.reserve(num_model);\n    model_configs.reserve(num_model);\n    for (int i = 0; i < num_model; ++i) {\n      const auto& [model_str, model_lib] = models_and_model_libs[i];\n      Result<tvm::ffi::json::Object> model_config_res = Model::LoadModelConfig(model_str);\n      if (model_config_res.IsErr()) {\n        return TResult::Error(\"Model \" + std::to_string(i) +\n                              \" has invalid mlc-chat-config.json: \" + model_config_res.UnwrapErr());\n      }\n      model_libs.push_back(model_lib);\n      model_configs.push_back(model_config_res.Unwrap());\n    }\n\n    // kick in mock path so we don't have to load in models\n    if (models_and_model_libs[0].second == \"mock://echo\") {\n      return MockEchoEngineImpl::Create(engine_config_json_str,\n                                        n->estate_->request_stream_callback_, model_configs[0]);\n    }\n\n    auto [session, num_shards, model_num_pipeline_stages] =\n        n->CreateDiscoSession(model_libs, model_configs, device);\n\n    // - Initialize each model independently.\n    n->models_.clear();\n    for (int i = 0; i < num_model; ++i) {\n      const auto& [model_str, model_lib] = models_and_model_libs[i];\n      Model model = Model::Create(model_lib, model_str, model_configs[i], device, session,\n                                  num_shards, model_num_pipeline_stages[i],\n                                  /*trace_enabled=*/trace_recorder.defined());\n      n->models_.push_back(model);\n    }\n    // - Initialize NVSHMEM\n    n->estate_->disaggregation = n->models_[0]->GetMetadata().disaggregation;\n    if (n->estate_->disaggregation) {\n      LOG(INFO) << \"Initializing NVSHMEM\";\n      char* 
nvshmem_init_config_json_char = std::getenv(\"MLC_NVSHMEM_INIT_CONFIG_JSON_STR\");\n      TVM_FFI_ICHECK(nvshmem_init_config_json_char != nullptr)\n          << \"The environment variable MLC_NVSHMEM_INIT_CONFIG_JSON_STR must be set.\";\n      std::string f_name = \"runtime.disco.nvshmem.init_nvshmem_wrapper\";\n      if (session != nullptr) {\n        n->DebugCallFuncOnAllAllWorker(f_name, String(nvshmem_init_config_json_char));\n      } else {\n        static Function func = Function::GetGlobalRequired(f_name);\n        func(String(nvshmem_init_config_json_char));\n      }\n      LOG(INFO) << \"NVSHMEM initialized successfully.\";\n    }\n\n    // - Automatically infer the missing fields in EngineConfig JSON strings\n    // and get the final EngineConfig.\n    Result<EngineConfig> engine_config_res =\n        n->AutoDecideEngineConfig(engine_config_json_str, model_configs);\n    if (engine_config_res.IsErr()) {\n      return TResult::Error(engine_config_res.UnwrapErr());\n    }\n    EngineConfig engine_config = engine_config_res.Unwrap();\n    {\n      if (engine_config->prefix_cache_mode == PrefixCacheMode::kRadix) {\n        n->estate_->prefix_cache = PrefixCache::CreateRadixPrefixCache(\n            static_cast<size_t>(engine_config->prefix_cache_max_num_recycling_seqs),\n            [engine_ptr = n.get()](int64_t seq_id) {\n              RemoveRequestFromModel(engine_ptr->estate_, seq_id, engine_ptr->models_);\n              engine_ptr->estate_->id_manager.RecycleId(seq_id);\n            });\n      } else if (engine_config->prefix_cache_mode == PrefixCacheMode::kDisable) {\n        n->estate_->prefix_cache = PrefixCache::CreateNoPrefixCache();\n      } else {\n        LOG(FATAL) << \"Unsupported prefix cache mode: \"\n                   << static_cast<int>(engine_config->prefix_cache_mode);\n      }\n      if (engine_config->speculative_mode != SpeculativeMode::kDisable &&\n          engine_config->prefill_mode == PrefillMode::kHybrid) {\n        
engine_config->prefill_mode = PrefillMode::kChunked;\n        LOG(WARNING)\n            << \"Hybrid prefill mode fallbacks to chunked prefill, due to speculative mode is \"\n               \"enabled and not implemented with hybrid prefill yet.\";\n      }\n    }\n    // - Load model weights, create KV cache and workspace.\n    n->model_workspaces_.clear();\n    for (const Model& model : n->models_) {\n      model->LoadParams();\n      model->SetMaxNumSequence(engine_config->max_num_sequence);\n      model->SetPrefillChunkSize(engine_config->prefill_chunk_size);\n      model->CreateKVCache(engine_config->kv_cache_page_size, engine_config->max_num_sequence,\n                           engine_config->max_total_sequence_length,\n                           engine_config->prefill_chunk_size, engine_config->max_history_size);\n      n->model_workspaces_.push_back(\n          ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});\n    }\n    // - Initialize tokenizer and grammar\n    n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));\n    n->token_table_ = n->tokenizer_->PostProcessedTokenTable();\n    n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_);\n    // - Create the logit processor and sampler, and\n    // the DraftTokenWorkspaceManager for speculative decoding.\n    int max_num_tokens = engine_config->max_num_sequence;\n    DraftTokenWorkspaceManager draft_token_workspace_manager{nullptr};\n    if (engine_config->speculative_mode != SpeculativeMode::kDisable) {\n      // multiply max num_tokens by two so we can do ping-pong swaping during draft/verify process\n      draft_token_workspace_manager =\n          n->models_[0]->CreateDraftTokenWorkspaceManager(max_num_tokens * 2);\n      draft_token_workspace_manager->AllocWorkspace(\n          &n->model_workspaces_[0],\n          /*require_hidden_states=*/engine_config->speculative_mode == SpeculativeMode::kEagle);\n    
}\n    LogitProcessor logit_processor =\n        n->models_[0]->CreateLogitProcessor(max_num_tokens, trace_recorder);\n    Sampler sampler = n->models_[0]->CreateSampler(\n        max_num_tokens, static_cast<int>(n->models_.size()), trace_recorder);\n    // - Initialize engine actions that represent state transitions.\n    if (engine_config->speculative_mode != SpeculativeMode::kDisable) {\n      n->estate_->spec_draft_length = engine_config->spec_draft_length;\n    }\n    n->actions_ =\n        CreateEngineActions(n->models_, engine_config, model_configs, n->model_workspaces_,\n                            logit_processor, sampler, draft_token_workspace_manager, n->tokenizer_,\n                            n->trace_recorder_, n->estate_->request_stream_callback_, device);\n    n->draft_token_workspace_manager_ = draft_token_workspace_manager;\n    // - Automatically set the threading backend max concurrency.\n    n->engine_config_ = engine_config;\n    n->SetThreadMaxConcurrency();\n    // - Get the default generation config from the first model.\n    GenerationConfig default_generation_cfg =\n        GenerationConfig::GetDefaultFromModelConfig(model_configs[0]);\n    return TResult::Ok({std::move(n), std::move(engine_config), std::move(default_generation_cfg)});\n  }\n\n  void Reset() final {\n    AbortAllRequests();\n    estate_->Reset();\n    for (Model model : models_) {\n      model->Reset();\n    }\n  }\n\n  bool Empty() final { return estate_->running_queue.empty() && estate_->waiting_queue.empty(); }\n\n  String JSONMetrics() final { return tvm::ffi::json::Stringify(estate_->metrics.AsJSON(), 2); }\n\n  FRequestStreamCallback GetRequestStreamCallback() final {\n    return estate_->request_stream_callback_;\n  }\n\n  void SetRequestStreamCallback(FRequestStreamCallback request_stream_callback) final {\n    estate_->request_stream_callback_ = std::move(request_stream_callback);\n  }\n\n  // string back error node\n  void StreamBackError(Request request, String 
finish_reason) {\n    StreamBackErrorImpl(request, estate_->request_stream_callback_, finish_reason);\n  }\n\n  /***************** High-level Request Management *****************/\n\n  void HandleSpecialRequests(Request request) {\n    auto special_request = request->generation_cfg->debug_config.special_request;\n    switch (special_request) {\n      case SpecialRequestKind::kQueryEngineMetrics: {\n        Array<RequestStreamOutput> output = {\n            RequestStreamOutput::Usage(request->id, estate_->metrics.AsUsageJSONStr())};\n        estate_->request_stream_callback_(output);\n        break;\n      }\n      default:\n        break;\n    }\n  }\n\n  /*!\n   * \\brief Handle the given disaggregation request.\n   * Return true if skipping the subsequent AddRequest process.\n   */\n  bool HandleDisaggRequest(Request request) {\n    DisaggConfig disagg_config = request->generation_cfg->debug_config.disagg_config;\n    DisaggRequestKind kind = disagg_config.kind;\n    if (kind == DisaggRequestKind::kPrepareReceive) {\n      // No-op.\n      return false;\n    } else if (kind == DisaggRequestKind::kRemoteSend) {\n      int input_length = 0;\n      for (Data input : request->inputs) {\n        input_length += input->GetLength();\n      }\n      // - Truncate the inputs to the desired prefill length (specified by \"end\").\n      int kv_window_begin = disagg_config.kv_window_begin.value_or(0);\n      int kv_window_end = disagg_config.kv_window_end.value_or(input_length);\n      TVM_FFI_ICHECK_GE(kv_window_begin, 0);\n      if (kv_window_end < 0) {\n        kv_window_end = input_length + kv_window_end;\n      }\n      TVM_FFI_ICHECK_LT(kv_window_end, input_length)\n          << \"Prefill the full input on the remote machine is not supported.\";\n      TVM_FFI_ICHECK_LT(kv_window_begin, kv_window_end)\n          << \"\\\"begin >= end\\\" is not supported by remote prefill\";\n      request->inputs = SplitData(request->inputs, input_length, kv_window_end).first;\n      
// - Check the invariant: \"end - begin\" equals the expanded metadata length.\n      TVM_FFI_ICHECK_EQ(disagg_config.kv_append_metadata.size(), models_.size());\n      for (const IntTuple& compressed_kv_append_metadata : disagg_config.kv_append_metadata) {\n        TVM_FFI_ICHECK(!compressed_kv_append_metadata.empty());\n        int num_segments = compressed_kv_append_metadata[0];\n        TVM_FFI_ICHECK_EQ(compressed_kv_append_metadata.size(), num_segments * 2 + 1);\n        int transmission_length = 0;\n        for (int i = 0; i < num_segments; ++i) {\n          transmission_length += compressed_kv_append_metadata[i * 2 + 2];\n        }\n        TVM_FFI_ICHECK_EQ(transmission_length, kv_window_end - kv_window_begin);\n      }\n      // - Override the \"n\" in generation config to 1.\n      ObjectPtr<GenerationConfigNode> updated_generation_cfg =\n          tvm::ffi::make_object<GenerationConfigNode>(*request->generation_cfg.get());\n      updated_generation_cfg->n = 1;\n      request->generation_cfg = GenerationConfig(updated_generation_cfg);\n      return false;\n    } else if (kind == DisaggRequestKind::kStartGeneration) {\n      auto it_rstate = estate_->request_states.find(request->id);\n      TVM_FFI_ICHECK(it_rstate != estate_->request_states.end());\n      TVM_FFI_ICHECK(!it_rstate->second->entries.empty());\n      request = it_rstate->second->entries[0]->request;\n      TVM_FFI_ICHECK(request->generation_cfg->debug_config.disagg_config.kind ==\n                     DisaggRequestKind::kPrepareReceive);\n      int input_length = 0;\n      for (Data input : request->inputs) {\n        input_length += input->GetLength();\n      }\n      // - Truncate the inputs to the desired prefill length (specified by \"end\").\n      int kv_window_begin = disagg_config.kv_window_begin.value_or(0);\n      int kv_window_end = disagg_config.kv_window_end.value_or(input_length);\n      TVM_FFI_ICHECK_EQ(kv_window_end, input_length);\n      if (kv_window_begin < 0) {\n        
kv_window_begin = input_length + kv_window_begin;\n      }\n      TVM_FFI_ICHECK_GE(kv_window_begin, 0);\n      TVM_FFI_ICHECK_LT(kv_window_begin, input_length);\n      // The request is not supposed to be in running queue nor waiting queue.\n      auto it_running =\n          std::find(estate_->running_queue.begin(), estate_->running_queue.end(), request);\n      auto it_waiting =\n          std::find(estate_->waiting_queue.begin(), estate_->waiting_queue.end(), request);\n      TVM_FFI_ICHECK(it_running == estate_->running_queue.end());\n      TVM_FFI_ICHECK(it_waiting == estate_->waiting_queue.end());\n\n      RequestState rstate = it_rstate->second;\n      ObjectPtr<GenerationConfigNode> updated_generation_cfg =\n          tvm::ffi::make_object<GenerationConfigNode>(*request->generation_cfg.get());\n      // - Split the input data into two parts at the position \"kv_window_begin\".\n      TVM_FFI_ICHECK(!request->inputs.empty());\n      auto [lhs_data, rhs_data] = SplitData(request->inputs, input_length, kv_window_begin);\n      if (input_length - kv_window_begin == 1 && request->generation_cfg->n == 1) {\n        // - Commit the last token id to the request states.\n        TVM_FFI_ICHECK_EQ(rhs_data.size(), 1);\n        const auto* token_data = rhs_data.back().as<TokenDataNode>();\n        TVM_FFI_ICHECK(token_data != nullptr);\n        TVM_FFI_ICHECK_EQ(token_data->GetLength(), 1);\n        SampleResult last_token;\n        last_token.sampled_token_id = {token_data->token_ids.back(), 1.0};\n        for (RequestModelState mstate : rstate->entries[0]->mstates) {\n          mstate->CommitToken(last_token);\n          TVM_FFI_ICHECK_EQ(mstate->committed_tokens.size(), 1);\n        }\n        // - Set \"next_callback_token_pos\" so that this token will not be streamed back to user.\n        rstate->entries[0]->next_callback_token_pos = 1;\n        // - Update the request input.\n        request->inputs = lhs_data;\n        // - Increment the max_tokens in 
generation config.\n        if (request->generation_cfg->max_tokens != -1) {\n          ++updated_generation_cfg->max_tokens;\n        }\n      } else {\n        // Since there are multiple tokens to prefill, we add the remaining inputs\n        // to the request's RequestModelStates for prefill.\n        for (RequestModelState mstate : rstate->entries[0]->mstates) {\n          mstate->inputs = rhs_data;\n        }\n        // Add to waiting queue for prefill.\n        estate_->waiting_queue.insert(estate_->waiting_queue.begin(), request);\n      }\n      estate_->running_queue.push_back(request);\n      // Erase the disaggregation request kind.\n      updated_generation_cfg->debug_config.disagg_config.kind = DisaggRequestKind::kNone;\n      request->generation_cfg = GenerationConfig(updated_generation_cfg);\n      estate_->running_rsentries_changed = true;\n      return true;\n    }\n    LOG(FATAL) << \"Cannot reach here\";\n    throw;\n  }\n\n  void AddRequest(Request request) final {\n    NVTXScopedRange nvtx_scope(\"Add request \" + request->id);\n    // special requests do not involve generation\n    if (request->generation_cfg->debug_config.special_request != SpecialRequestKind::kNone) {\n      this->HandleSpecialRequests(request);\n      return;\n    }\n\n    RECORD_EVENT(trace_recorder_, request->id, \"request added to engine\");\n    auto add_time_point = std::chrono::high_resolution_clock::now();\n\n    // Get a request copy where all text inputs are tokenized.\n    request = Request::FromUntokenized(request, tokenizer_);\n    TVM_FFI_ICHECK_NE(request->prompt_tokens, -1);\n\n    if (request->prompt_tokens >= engine_config_->max_single_sequence_length &&\n        estate_->request_stream_callback_ != nullptr) {\n      this->StreamBackError(request, \"length\");\n      return;\n    }\n\n    // Handle disaggregation requests.\n    if (request->generation_cfg->debug_config.disagg_config.kind != DisaggRequestKind::kNone) {\n      bool return_now = 
this->HandleDisaggRequest(request);\n      if (return_now) {\n        return;\n      }\n    }\n\n    // Append to the waiting queue and create the request state.\n    estate_->waiting_queue.push_back(request);\n\n    int n = request->generation_cfg->n;\n    int rng_seed = request->generation_cfg->seed;\n    auto compiled_grammar = GetGrammarFromResponseFormat(request->generation_cfg->response_format);\n\n    std::vector<RequestStateEntry> rsentries;\n    // Create the request state entry for the input.\n    rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(), rng_seed,\n                           token_table_, compiled_grammar);\n    if (n > 1) {\n      // Then create a request state entry for each parallel generation branch.\n      // We add an offset to the rng seed so that the generations differ.\n      rsentries.reserve(n + 1);\n      rsentries[0]->child_indices.reserve(n);\n      for (int i = 0; i < n; ++i) {\n        rsentries[0]->child_indices.push_back(rsentries.size());\n        rsentries.emplace_back(request, models_.size(), estate_->id_manager.GetNewId(),\n                               rng_seed + i + 1, token_table_, compiled_grammar,\n                               /*parent_idx=*/0);\n      }\n    }\n    RequestState rstate = RequestState(std::move(rsentries), n, add_time_point);\n    for (const RequestStateEntry& rsentry : rstate->entries) {\n      // Set the back reference.\n      // Note: a raw pointer is used to avoid a cyclic reference.\n      rsentry->rstate = rstate.operator->();\n    }\n    request->rstate = rstate.operator->();\n    estate_->request_states.emplace(request->id, rstate);\n  }\n\n  void AbortRequest(const String& request_id) final {\n    AbortRequestImpl(estate_, models_, request_id);\n  }\n\n  void AbortAllRequests() final {\n    // - Collect all the request ids.\n    std::vector<String> request_ids;\n    request_ids.reserve(estate_->request_states.size());\n    for (const auto& kv : 
estate_->request_states) {\n      request_ids.push_back(kv.first);\n    }\n    // - Abort all the requests.\n    for (const String& request_id : request_ids) {\n      AbortRequest(request_id);\n    }\n  }\n\n  /*********************** Engine Action ***********************/\n\n  void Step() final {\n    TVM_FFI_ICHECK(estate_->request_stream_callback_ != nullptr)\n        << \"The request stream callback is not set. Engine cannot execute.\";\n    for (EngineAction action : actions_) {\n      Array<Request> processed_requests;\n      {\n        NVTXScopedRange nvtx_scope(\"Action step\");\n        processed_requests = action->Step(estate_);\n      }\n      if (!processed_requests.empty()) {\n        ActionStepPostProcess(processed_requests, estate_, models_, tokenizer_,\n                              estate_->request_stream_callback_,\n                              engine_config_->max_single_sequence_length,\n                              draft_token_workspace_manager_, trace_recorder_);\n        return;\n      }\n    }\n    TVM_FFI_ICHECK(estate_->running_queue.empty())\n        << \"Internal assumption violated: It is expected that an engine step takes at least one \"\n           \"action (e.g. prefill, decode, etc.) 
but it does not.\";\n  }\n\n  /************** Utility Functions **************/\n  std::tuple<Optional<Session>, int, std::vector<int>> CreateDiscoSession(\n      const std::vector<std::string>& model_libs,\n      const std::vector<tvm::ffi::json::Object>& model_configs, Device device) {\n    const auto& base_model_config = model_configs[0];\n\n    auto f_get_num_shards_num_stages =\n        [&device](const std::string& model_lib,\n                  const tvm::ffi::json::Object& model_config) -> std::pair<int, int> {\n      if (!StartsWith(model_lib, \"system://\")) {\n        Module executable = ffi::Module::LoadFromFile(model_lib);\n        Optional<Function> fload_exec = executable->GetFunction(\"vm_load_executable\");\n        TVM_FFI_ICHECK(fload_exec.defined()) << \"TVM runtime cannot find vm_load_executable\";\n        Module local_vm = fload_exec.value()().cast<Module>();\n        local_vm->GetFunction(\"vm_initialization\")\n            .value()(static_cast<int>(device.device_type), device.device_id,\n                     static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled),\n                     static_cast<int>(kDLCPU), 0,\n                     static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled));\n        ModelMetadata metadata = ModelMetadata::FromModule(local_vm, std::move(model_config));\n        return {metadata.tensor_parallel_shards, metadata.pipeline_parallel_stages};\n      } else {\n        return {1, 1};\n      }\n    };\n\n    int num_shards = -1;\n    int max_num_stages = 1;\n    std::vector<int> model_num_pipeline_stages;\n    model_num_pipeline_stages.reserve(model_libs.size());\n    TVM_FFI_ICHECK_EQ(model_libs.size(), model_configs.size());\n    for (int i = 0; i < static_cast<int>(model_libs.size()); ++i) {\n      auto [model_num_shards, model_num_stages] =\n          f_get_num_shards_num_stages(model_libs[i], model_configs[i]);\n      model_num_pipeline_stages.push_back(model_num_stages);\n      max_num_stages = 
std::max(max_num_stages, model_num_stages);\n      if (i == 0) {\n        num_shards = model_num_shards;\n      } else {\n        TVM_FFI_ICHECK_EQ(model_num_shards, num_shards)\n            << \"Inconsistent tensor_parallel_shards values across models. Some model is compiled \"\n               \"with tensor_parallel_shards \"\n            << num_shards << \" and some other model is compiled with tensor_parallel_shards \"\n            << model_num_shards;\n      }\n    }\n\n    Optional<Session> session = std::nullopt;\n    int num_workers = num_shards * max_num_stages;\n    if (num_workers > 1) {\n#ifndef MLC_SINGLE_GPU_ONLY\n      constexpr const char* f_create_process_pool = \"runtime.disco.create_process_pool\";\n      if (!Function::GetGlobal(f_create_process_pool).has_value()) {\n        LOG(FATAL) << \"Cannot find process launcher `\" << f_create_process_pool << \"`. \"\n                   << \"Multi-GPU inference depends on MLC LLM Python API to launch process.\";\n      }\n      std::string ccl;\n      if (device.device_type == kDLCUDA) {\n        ccl = \"nccl\";\n      } else if (device.device_type == kDLROCM) {\n        ccl = \"rccl\";\n      } else {\n        LOG(FATAL) << \"ValueError: Multi-GPU on device \" << DLDeviceType2Str(device.device_type)\n                   << \" is not supported. 
Currently, only NCCL and RCCL are integrated.\";\n      }\n      std::vector<int64_t> device_ids(num_workers);\n      for (int i = 0; i < num_workers; ++i) {\n        // device.device_id is the start of the worker 0 of this model\n        device_ids[i] = device.device_id + i;\n      }\n      const std::string& green_text_begin = \"\\033[92m\";\n      const std::string& yellow_text_begin = \"\\033[93m\";\n      const std::string& colored_text_end = \"\\033[0m\";\n      auto [socket_host, socket_port] = GetEnvSocketHostPort();\n      if (max_num_stages > 1 && socket_host.has_value()) {\n        // Use SocketSession when pipeline parallelism enabled and socket host and port are set.\n        TVM_FFI_ICHECK_GT(socket_port, 0)\n            << \"Invalid MLC socket port \" << socket_port\n            << \". Please set a valid port value in environment variable \\\"MLC_SOCKET_PORT\\\".\";\n        LOG(INFO) << \"Creating MLC socket session with socket host \" << socket_host.value()\n                  << \" and port \" << socket_port;\n        LOG(INFO) << \"Please launch \" << green_text_begin << max_num_stages - 1 << colored_text_end\n                  << \" remote socket node(s) with the following command to proceed:\\n\\t\"\n                  << green_text_begin << \"python -m mlc_llm.cli.disco_remote_socket_session \"\n                  << (socket_host.value() == \"0.0.0.0\" ? 
\"<YOUR_NODE_IP>\" : socket_host.value())\n                  << \" \" << socket_port << \" \" << num_shards << colored_text_end;\n        static Function f_create_socket_sess =\n            Function::GetGlobalRequired(\"runtime.disco.SocketSession\");\n        Session sess =\n            f_create_socket_sess(max_num_stages, num_shards, /*num_groups=*/max_num_stages,\n                                 socket_host.value(), socket_port)\n                .cast<Session>();\n        session = std::move(sess);\n      } else {\n        if (max_num_stages > 1) {\n          LOG(INFO)\n              << yellow_text_begin\n              << \"Model is enabled with \\\"pipeline_parallel_stages\\\" but the socket host/port is \"\n                 \"not set. If you intend to run the model on multiple nodes, please set \"\n                 \"environment variables \\\"MLC_SOCKET_HOST\\\" and \\\"MLC_SOCKET_PORT\\\" and run again.\"\n              << colored_text_end;\n        }\n        // Use ProcessSession otherwise.\n        session = Session::ProcessSession(num_workers, max_num_stages, f_create_process_pool,\n                                          \"mlc_llm.cli.worker\");\n      }\n      session.value()->InitCCL(ccl, Shape(device_ids));\n#else\n      LOG(FATAL) << \"MLC_SINGLE_GPU_ONLY is specified. 
Multi-GPU is not enabled.\";\n#endif  // MLC_SINGLE_GPU_ONLY\n    }\n    return {session, num_shards, model_num_pipeline_stages};\n  }\n\n  /************** Debug/Profile **************/\n\n  void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) final {\n    TVM_FFI_ICHECK(!models_.empty()) << \"There is no model running in Engine.\";\n    models_[0]->DebugCallFuncOnAllAllWorker(func_name, func_args);\n  }\n\n private:\n  Result<EngineConfig> AutoDecideEngineConfig(\n      const std::string& engine_config_json_str,\n      const std::vector<tvm::ffi::json::Object>& model_configs) {\n    using TResult = Result<EngineConfig>;\n    tvm::ffi::String err;\n    auto config_json = tvm::ffi::json::Parse(engine_config_json_str, &err);\n    if (!err.empty()) {\n      return TResult::Error(err);\n    }\n    tvm::ffi::json::Object config = config_json.cast<tvm::ffi::json::Object>();\n    ObjectPtr<EngineConfigNode> n = tvm::ffi::make_object<EngineConfigNode>();\n\n    // - Get the engine mode and maximum GPU utilization for inference.\n    EngineMode mode = EngineModeFromString(json::Lookup<std::string>(config, \"mode\"));\n    double gpu_memory_utilization =\n        json::LookupOrDefault<double>(config, \"gpu_memory_utilization\", n->gpu_memory_utilization);\n    bool verbose = json::LookupOrDefault<bool>(config, \"verbose\", n->verbose);\n\n    // - Get the config fields that can be automatically inferred.\n    std::optional<int64_t> max_num_sequence =\n        json::LookupOptional<int64_t>(config, \"max_num_sequence\");\n    std::optional<int64_t> max_total_sequence_length =\n        json::LookupOptional<int64_t>(config, \"max_total_sequence_length\");\n    std::optional<int64_t> max_single_sequence_length =\n        json::LookupOptional<int64_t>(config, \"max_single_sequence_length\");\n    std::optional<int64_t> prefill_chunk_size =\n        json::LookupOptional<int64_t>(config, \"prefill_chunk_size\");\n    std::optional<int64_t> 
max_history_size =\n        json::LookupOptional<int64_t>(config, \"max_history_size\");\n    std::optional<std::string> kv_state_kind_str =\n        json::LookupOptional<std::string>(config, \"kv_state_kind\");\n    InferrableEngineConfig inferrable_cfg{max_num_sequence, max_total_sequence_length,\n                                          max_single_sequence_length, prefill_chunk_size,\n                                          max_history_size};\n\n    // - Get the model metadata.\n    std::vector<ModelMetadata> model_metadata;\n    for (const Model& model : models_) {\n      model_metadata.push_back(model->GetMetadata());\n    }\n    // - Select from kv cache or RNN state.\n    Result<bool> use_kv_cache = ModelsUseKVCache(model_configs);\n    if (use_kv_cache.IsErr()) {\n      return TResult::Error(use_kv_cache.UnwrapErr());\n    }\n    Result<InferrableEngineConfig> inferrable_cfg_res;\n    if (use_kv_cache.Unwrap()) {\n      // - Infer configuration.\n      inferrable_cfg_res = InferrableEngineConfig::InferForKVCache(\n          mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,\n          verbose);\n    } else {\n      // - Infer configuration.\n      inferrable_cfg_res = InferrableEngineConfig::InferForRNNState(\n          mode, device_, gpu_memory_utilization, model_configs, model_metadata, inferrable_cfg,\n          verbose);\n    }\n\n    if (inferrable_cfg_res.IsErr()) {\n      return TResult::Error(inferrable_cfg_res.UnwrapErr());\n    }\n    inferrable_cfg = inferrable_cfg_res.Unwrap();\n    TVM_FFI_ICHECK(inferrable_cfg.max_num_sequence.has_value());\n    TVM_FFI_ICHECK(inferrable_cfg.max_total_sequence_length.has_value());\n    use_kv_cache = ModelsUseKVCache(model_configs);\n    if (use_kv_cache.Unwrap()) {\n      TVM_FFI_ICHECK(inferrable_cfg.max_single_sequence_length.has_value());\n    }\n    TVM_FFI_ICHECK(inferrable_cfg.prefill_chunk_size.has_value());\n    
TVM_FFI_ICHECK(inferrable_cfg.max_history_size.has_value());\n    return TResult::Ok(EngineConfig::FromJSONAndInferredConfig(config, inferrable_cfg));\n  }\n\n  /*! \\brief Set the maximum threading backend concurrency. */\n  void SetThreadMaxConcurrency() {\n    int host_cpu_usage = 1;\n    for (Model model : models_) {\n      host_cpu_usage += model->EstimateHostCPURequirement();\n    }\n    if (host_cpu_usage > 1) {\n      int max_concurrency = tvm::runtime::threading::MaxConcurrency();\n      tvm::runtime::threading::SetMaxConcurrency(std::min(\n          std::max(max_concurrency - host_cpu_usage, 1), engine_config_->max_num_sequence));\n    }\n  }\n\n  /*! \\brief Create a grammar init context according to the response format. If the response format\n   * is not JSON, return std::nullopt. */\n  std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(\n      const ResponseFormat& response_format) {\n    if (response_format.type != \"json_object\") {\n      return std::nullopt;\n    } else if (!response_format.schema) {\n      return cached_grammar_compiler_.GetCompiledGrammarForJSON();\n    } else {\n      return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema(\n          response_format.schema.value());\n    }\n  }\n\n  // Engine state, managing requests and request states.\n  EngineState estate_;\n  // Configurations and singletons\n  EngineConfig engine_config_;\n  // internal tokenizer\n  Tokenizer tokenizer_;\n  std::vector<std::string> token_table_;\n  // Cached grammar compiler for grammar matching.\n  xgrammar::CachedGrammarCompiler cached_grammar_compiler_;\n  // Models\n  Array<Model> models_;\n  // Device that the models run on.\n  Device device_;\n  // Workspace of each model.\n  std::vector<ModelWorkspace> model_workspaces_;\n  // Engine actions.\n  Array<EngineAction> actions_;\n  // Draft token workspace manager for speculative decoding.\n  Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager_;\n  // Event 
trace recorder.\n  Optional<EventTraceRecorder> trace_recorder_;\n};\n\nResult<EngineCreationOutput> Engine::Create(const std::string& engine_config_json_str,\n                                            Device device,\n                                            FRequestStreamCallback request_stream_callback,\n                                            Optional<EventTraceRecorder> trace_recorder) {\n  return EngineImpl::Create(engine_config_json_str, device, request_stream_callback,\n                            std::move(trace_recorder));\n}\n\n/*! \\brief Clear global memory manager */\nvoid ClearGlobalMemoryManager() {\n  static const char* kFunc = \"vm.builtin.memory_manager.clear\";\n  static Function f = Function::GetGlobalRequired(kFunc);\n  f();\n}\n\nclass EngineModule : public ffi::ModuleObj {\n public:\n  TVM_MODULE_VTABLE_BEGIN(\"mlc.serve.engine\");\n  TVM_MODULE_VTABLE_ENTRY(\"init\", &EngineModule::Init);\n  TVM_MODULE_VTABLE_ENTRY(\"add_request\", &EngineModule::AddRequest);\n  TVM_MODULE_VTABLE_ENTRY(\"create_request\", &EngineModule::CreateRequest);\n  TVM_MODULE_VTABLE_ENTRY(\"abort_request\", &EngineModule::Abort);\n  TVM_MODULE_VTABLE_ENTRY(\"step\", &EngineModule::Step);\n  TVM_MODULE_VTABLE_ENTRY(\"reset\", &EngineModule::Reset);\n  TVM_MODULE_VTABLE_ENTRY(\"json_metrics\", &EngineModule::JSONMetrics);\n  TVM_MODULE_VTABLE_ENTRY(\"get_request_stream_callback\", &EngineModule::GetRequestStreamCallback);\n  TVM_MODULE_VTABLE_ENTRY(\"set_request_stream_callback\", &EngineModule::SetRequestStreamCallback);\n  TVM_MODULE_VTABLE_END();\n\n  /*! \\brief Initialize the engine with config and other fields. 
*/\n  void Init(const std::string& engine_config_json_str, Device device,\n            FRequestStreamCallback request_stream_callback,\n            Optional<EventTraceRecorder> trace_recorder) {\n    Result<EngineCreationOutput> output_res = Engine::Create(\n        engine_config_json_str, device, request_stream_callback, std::move(trace_recorder));\n    TVM_FFI_ICHECK(output_res.IsOk()) << output_res.UnwrapErr();\n    EngineCreationOutput output = output_res.Unwrap();\n    this->engine_ = std::move(output.reloaded_engine);\n    this->default_generation_config_ = output.default_generation_cfg;\n  }\n  /*! \\brief Construct an EngineModule. */\n  static ffi::Module Create() { return ffi::Module(tvm::ffi::make_object<EngineModule>()); }\n  /*! \\brief Redirection to `Engine::AddRequest`. */\n  void AddRequest(Request request) { return GetEngine()->AddRequest(std::move(request)); }\n  /*! \\brief Redirection to `Engine::AbortRequest`. */\n  void Abort(const String& request_id) { return GetEngine()->AbortRequest(request_id); }\n  /*! \\brief Create request with given arguments and the engine default generation config. */\n  Request CreateRequest(String id, Array<Data> inputs, String generation_cfg_json_str) {\n    auto config = json::ParseToJSONObject(generation_cfg_json_str);\n    auto gen_config = GenerationConfig::FromJSON(config, default_generation_config_);\n    TVM_FFI_ICHECK(gen_config.IsOk()) << gen_config.UnwrapErr();\n    return Request(std::move(id), std::move(inputs), gen_config.Unwrap());\n  }\n  /*! \\brief Redirection to `Engine::Step`. */\n  void Step() { return GetEngine()->Step(); }\n  /*! \\brief Redirection to `Engine::GetRequestStreamCallback`. */\n  FRequestStreamCallback GetRequestStreamCallback() {\n    return GetEngine()->GetRequestStreamCallback();\n  }\n  /*! 
\\brief Redirection to `Engine::SetRequestStreamCallback` */\n  void SetRequestStreamCallback(FRequestStreamCallback request_stream_callback) {\n    GetEngine()->SetRequestStreamCallback(request_stream_callback);\n  }\n  /*! \\brief Redirection to `Engine::Reset`. */\n  void Reset() { return GetEngine()->Reset(); }\n\n  /*! \\brief Redirection to `Engine::JSONMetrics`. */\n  String JSONMetrics() { return GetEngine()->JSONMetrics(); }\n\n private:\n  Engine* GetEngine() {\n    TVM_FFI_ICHECK(engine_ != nullptr) << \"Engine is not initialized via init\";\n    return engine_.get();\n  }\n\n  std::unique_ptr<Engine> engine_ = nullptr;\n  GenerationConfig default_generation_config_;\n};\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef().def(\"mlc.serve.create_engine\", EngineModule::Create);\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine.h\n * \\brief The header of serving engine in MLC LLM.\n */\n#ifndef MLC_LLM_SERVE_ENGINE_H_\n#define MLC_LLM_SERVE_ENGINE_H_\n\n#include \"data.h\"\n#include \"engine_state.h\"\n#include \"event_trace_recorder.h\"\n#include \"request.h\"\n#include \"request_state.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\nclass Engine;\n\n/*!\n * \\brief The output of engine creation, including the created engine and\n * the default generation config for requests.\n */\nstruct EngineCreationOutput {\n  std::unique_ptr<Engine> reloaded_engine;\n  EngineConfig completed_engine_config;\n  GenerationConfig default_generation_cfg;\n};\n\n/*!\n * \\brief The engine interface for request serving in MLC LLM.\n * The engine can run one or multiple LLM models internally for\n * text generation. Usually, when there are multiple models,\n * speculative inference will be activated, where the first model\n * (index 0) is the main \"large model\" that has better generation\n * quality, and all other models are \"small\" models that used for\n * speculation.\n * The engine receives requests from the \"AddRequest\" method. For\n * an given request, the engine will keep generating new tokens for\n * the request until finish (under certain criterion). 
After it finishes,\n * the engine will return the generation result through the callback\n * function provided by the request.\n * \\note For now, only one model running in the engine is supported.\n * Multi-model support such as speculative inference will\n * follow soon.\n *\n * The public interface of Engine has the following three categories:\n * - engine management,\n * - high-level request management,\n * - engine \"step\" action.\n */\nclass Engine {\n public:\n  /********************** Engine Management **********************/\n  virtual ~Engine() = default;\n\n  /*!\n   * \\brief Create an engine held in a unique pointer.\n   * \\param engine_config_json_str The serialized JSON string of the engine config.\n   * \\param device The device on which the models run.\n   * \\param request_stream_callback The request stream callback function of the engine.\n   * \\param trace_recorder Event trace recorder for requests.\n   * \\return The creation output, including the created engine, the completed engine\n   * config, and the default generation config.\n   */\n  static Result<EngineCreationOutput> Create(const std::string& engine_config_json_str,\n                                             Device device,\n                                             FRequestStreamCallback request_stream_callback,\n                                             Optional<EventTraceRecorder> trace_recorder);\n\n  /*! \\brief Reset the engine, cleaning up all running data and metrics. */\n  virtual void Reset() = 0;\n\n  /*! \\brief Check if the engine has no request to process. */\n  virtual bool Empty() = 0;\n\n  /*! \\brief Get the request stream callback function of the engine. */\n  virtual FRequestStreamCallback GetRequestStreamCallback() = 0;\n\n  /*! \\brief Set the request stream callback function of the engine. */\n  virtual void SetRequestStreamCallback(FRequestStreamCallback request_stream_callback) = 0;\n\n  /***************** High-level Request Management *****************/\n\n  /*! \\brief Add a new request to the engine. 
*/\n  virtual void AddRequest(Request request) = 0;\n\n  /*! \\brief Abort the input request (specified by its id string) from the engine. */\n  virtual void AbortRequest(const String& request_id) = 0;\n\n  /*! \\brief Abort all requests from the engine. */\n  virtual void AbortAllRequests() = 0;\n\n  /*********************** Engine Action ***********************/\n\n  /*!\n   * \\brief The main function through which the engine takes one step of action.\n   * At each step, the engine may decide to\n   * - run prefill for one (or more) requests,\n   * - run one-step decode for all existing requests\n   * ...\n   * At the end of certain actions (e.g., decode), the engine will\n   * check if any request has finished, and will return the\n   * generation results for those finished requests.\n   */\n  virtual void Step() = 0;\n\n  /************** Debug/Profile **************/\n\n  /*! \\brief Internal engine metrics. */\n  virtual String JSONMetrics() = 0;\n\n  /*! \\brief Call the given global function on all workers. For debugging purposes only. */\n  virtual void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) = 0;\n};\n\nvoid AbortRequestImpl(EngineState estate, const Array<Model>& models, const String& request_id,\n                      String finish_reason = \"abort\");\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_ENGINE_H_\n"
  },
  {
    "path": "cpp/serve/engine_actions/action.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/action.cc\n */\n\n#include \"action.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() { EngineActionObj::RegisterReflection(); }\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/action.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/action.h\n * \\brief The abstraction of actions (e.g., prefill/decode) that an\n * Engine can take at each time step.\n */\n#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_\n#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_\n\n#include \"../config.h\"\n#include \"../draft_token_workspace_manager.h\"\n#include \"../engine.h\"\n#include \"../engine_state.h\"\n#include \"../event_trace_recorder.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The abstraction of actions that an Engine can take at each time step.\n * The only core interface of an action is the `Step` function.\n * At a high level, the Step function takes the current engine state\n * as input, invokes model functions (such as batched-prefill or\n * batched-decode), runs the sampler to sample new tokens, and updates\n * the engine state.\n */\nclass EngineActionObj : public Object {\n public:\n  /*!\n   * \\brief The behavior of the engine action in a single step.\n   * \\param estate The engine state to be analyzed and updated.\n   * \\return The processed requests in this step.\n   */\n  virtual Array<Request> Step(EngineState estate) = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<EngineActionObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.EngineAction\", EngineActionObj, Object);\n};\n\n/*!\n * \\brief Managed reference of EngineActionObj.\n * It declares the full list of supported actions.\n * \\sa EngineActionObj\n */\nclass EngineAction : public ObjectRef {\n public:\n  /*!\n   * \\brief Create the action that prefills requests in the 
`waiting_queue`\n   * of the engine state.\n   * \\param models The models to run prefill in.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param model_workspaces The workspace of each model.\n   * \\param engine_config The engine config.\n   * \\param model_configs The config of each model.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction NewRequestPrefill(Array<Model> models, LogitProcessor logit_processor,\n                                        Sampler sampler,\n                                        std::vector<ModelWorkspace> model_workspaces,\n                                        EngineConfig engine_config,\n                                        std::vector<tvm::ffi::json::Object> model_configs,\n                                        Optional<EventTraceRecorder> trace_recorder);\n  /*!\n   * \\brief Create the action that prefills requests in the `waiting_queue`\n   * of the engine state.\n   * \\param models The models to run prefill in.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param model_workspaces The workspace of each model.\n   * \\param draft_token_workspace_manager The draft token workspace manager.\n   * \\param engine_config The engine config.\n   * \\param model_configs The config of each model.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction EagleNewRequestPrefill(\n      Array<Model> models, LogitProcessor logit_processor, Sampler sampler,\n      std::vector<ModelWorkspace> model_workspaces,\n      DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config,\n      std::vector<tvm::ffi::json::Object> model_configs,\n      Optional<EventTraceRecorder> trace_recorder);\n  /*!\n   * \\brief 
Create the action that runs one-step decode for requests in the\n   * `running_queue` of the engine state. Preempt low-priority requests\n   * accordingly when it is impossible to decode all the running requests.\n   * \\note The BatchDecode action **does not** take effect for speculative\n   * decoding scenarios where there are multiple models. For speculative\n   * decoding in the future, we will use other specific actions.\n   * \\param models The models to run decode in. When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   * \\param tokenizer The tokenizer of the engine.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param engine_config The engine config.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction BatchDecode(Array<Model> models, Tokenizer tokenizer,\n                                  LogitProcessor logit_processor, Sampler sampler,\n                                  EngineConfig engine_config,\n                                  Optional<EventTraceRecorder> trace_recorder);\n\n  /*!\n   * \\brief Create the action that runs one-step speculative draft proposal for\n   * requests in the `running_queue` of the engine state. Preempt low-priority requests\n   * accordingly when it is impossible to decode all the running requests.\n   * \\param models The models to run decode in. 
When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param model_workspaces The workspace of each model.\n   * \\param draft_token_workspace_manager The draft token workspace manager.\n   * \\param engine_config The engine config.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction BatchDraft(Array<Model> models, LogitProcessor logit_processor,\n                                 Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                 DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                 EngineConfig engine_config,\n                                 Optional<EventTraceRecorder> trace_recorder);\n\n  /*!\n   * \\brief Create the action that runs one-step speculative draft proposal for\n   * requests in the `running_queue` of the engine state. Preempt low-priority requests\n   * accordingly when it is impossible to decode all the running requests.\n   * \\param models The models to run decode in. 
When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param model_workspaces The workspace of each model.\n   * \\param draft_token_workspace_manager The draft token workspace manager.\n   * \\param engine_config The engine config.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,\n                                      Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                      DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                      EngineConfig engine_config,\n                                      Optional<EventTraceRecorder> trace_recorder);\n\n  /*!\n   * \\brief Create the action that runs one-step speculative verification for requests in the\n   * `running_queue` of the engine state. Preempt low-priority requests\n   * accordingly when it is impossible to decode all the running requests.\n   * \\param models The models to run decode in. 
When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param model_workspaces The workspace of each model.\n   * \\param draft_token_workspace_manager The draft token workspace manager.\n   * \\param engine_config The engine config.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction BatchVerify(Array<Model> models, LogitProcessor logit_processor,\n                                  Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                  DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                  EngineConfig engine_config,\n                                  Optional<EventTraceRecorder> trace_recorder);\n\n  /*!\n   * \\brief Create the action that runs one-step speculative verification for requests in the\n   * `running_queue` of the engine state. Preempt low-priority requests\n   * accordingly when it is impossible to decode all the running requests.\n   * \\param models The models to run decode in. 
When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   * \\param logit_processor The logit processor.\n   * \\param sampler The sampler to sample new tokens.\n   * \\param model_workspaces The workspace of each model.\n   * \\param draft_token_workspace_manager The draft token workspace manager.\n   * \\param engine_config The engine config.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction EagleBatchVerify(Array<Model> models, LogitProcessor logit_processor,\n                                       Sampler sampler,\n                                       std::vector<ModelWorkspace> model_workspaces,\n                                       DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                       EngineConfig engine_config,\n                                       Optional<EventTraceRecorder> trace_recorder);\n  /*!\n   * \\brief Create the action that executes the jump-forward decoding to predict the next tokens\n   * according to the grammar constraint. Does nothing for requests without a grammar. The\n   * predicted tokens will be fed to the next BatchDecode action. Retokenization may happen when\n   * the predicted string breaks the tokenization boundary.\n   * \\param models The models to run decode in. 
When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   * \\param tokenizer The tokenizer of the engine.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\return The created action object.\n   */\n  static EngineAction BatchJumpForward(Array<Model> models, Tokenizer tokenizer,\n                                       Optional<EventTraceRecorder> trace_recorder);\n\n  /*!\n   * \\brief Create the action that first makes a decision on whether to run speculative\n   * decoding or normal mode batch decode, and then runs the selected actions.\n   * \\param spec_decode_actions The actions for speculative decoding.\n   * \\param batch_decode_actions The actions for normal mode batch decoding.\n   * \\param engine_config The engine config.\n   * \\return The created action object.\n   */\n  static EngineAction AutoSpecDecode(std::vector<EngineAction> spec_decode_actions,\n                                     std::vector<EngineAction> batch_decode_actions,\n                                     EngineConfig engine_config);\n\n  /*!\n   * \\brief Create the action that runs the disaggregation preparation for prefill.\n   * \\param models The underlying models whose KV caches are to be updated.\n   * \\param engine_config The engine config.\n   * \\param model_configs The config of each model.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\param request_stream_callback The stream callback function to pass the prefill\n   * preparation result back, including the KV cache append metadata and the prefix\n   * matched length in the prefix cache.\n   * \\return The created action object.\n   */\n  static EngineAction DisaggPrepareReceive(Array<Model> models, EngineConfig engine_config,\n                                           std::vector<tvm::ffi::json::Object> model_configs,\n                                           Optional<EventTraceRecorder> trace_recorder,\n               
                            FRequestStreamCallback request_stream_callback);\n\n  /*!\n   * \\brief Create the action that runs the prefill and sends KV data to a remote instance.\n   * \\param models The underlying models whose KV caches are to be updated.\n   * \\param model_workspaces The workspace of each model.\n   * \\param engine_config The engine config.\n   * \\param model_configs The config of each model.\n   * \\param trace_recorder The event trace recorder for requests.\n   * \\param request_stream_callback The stream callback function to pass the prefill\n   * preparation result back, including the KV cache append metadata and the prefix\n   * matched length in the prefix cache.\n   * \\param device The device of the model for synchronization.\n   * \\return The created action object.\n   */\n  static EngineAction DisaggRemoteSend(Array<Model> models,\n                                       std::vector<ModelWorkspace> model_workspaces,\n                                       EngineConfig engine_config,\n                                       std::vector<tvm::ffi::json::Object> model_configs,\n                                       Optional<EventTraceRecorder> trace_recorder,\n                                       FRequestStreamCallback request_stream_callback,\n                                       Device device);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EngineAction, ObjectRef, EngineActionObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_H_\n"
  },
  {
    "path": "cpp/serve/engine_actions/action_commons.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/action_commons.cc\n */\n\n#include \"action_commons.h\"\n\n#include <tvm/runtime/nvtx.h>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nArray<EngineAction> CreateEngineActions(Array<Model> models, EngineConfig engine_config,\n                                        std::vector<tvm::ffi::json::Object> model_configs,\n                                        std::vector<ModelWorkspace> model_workspaces,\n                                        LogitProcessor logit_processor, Sampler sampler,\n                                        DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                        Tokenizer tokenizer,\n                                        Optional<EventTraceRecorder> trace_recorder,\n                                        FRequestStreamCallback request_stream_callback,\n                                        Device device) {\n  Array<EngineAction> actions;\n  ModelMetadata model_metadata = models[0]->GetMetadata();\n  if (engine_config->speculative_mode != SpeculativeMode::kDisable) {\n    // Speculative decoding is only possible for more than one model.\n    TVM_FFI_ICHECK_GT(models.size(), 1U);\n    if (engine_config->speculative_mode == SpeculativeMode::kEagle) {\n      TVM_FFI_ICHECK_GT(engine_config->spec_draft_length, 0)\n          << \"The automatic spec decoding does not support Eagle mode as of now.\";\n      actions = {EngineAction::EagleNewRequestPrefill(models,                         //\n                                                      logit_processor,                //\n                                                      sampler,                        //\n                                                      model_workspaces,               //\n                                                      draft_token_workspace_manager,  //\n                                                      
engine_config,                  //\n                                                      model_configs,                  //\n                                                      trace_recorder),\n                 EngineAction::EagleBatchDraft(models, logit_processor, sampler, model_workspaces,\n                                               draft_token_workspace_manager, engine_config,\n                                               trace_recorder),\n                 EngineAction::EagleBatchVerify(models, logit_processor, sampler, model_workspaces,\n                                                draft_token_workspace_manager, engine_config,\n                                                trace_recorder)};\n    } else if (engine_config->speculative_mode == SpeculativeMode::kMedusa) {\n      TVM_FFI_ICHECK_GT(engine_config->spec_draft_length, 0)\n          << \"The automatic spec decoding does not support Medusa mode as of now.\";\n      actions = {EngineAction::EagleNewRequestPrefill(models,                         //\n                                                      logit_processor,                //\n                                                      sampler,                        //\n                                                      model_workspaces,               //\n                                                      draft_token_workspace_manager,  //\n                                                      engine_config,                  //\n                                                      model_configs,                  //\n                                                      trace_recorder),\n                 EngineAction::EagleBatchVerify(models, logit_processor, sampler, model_workspaces,\n                                                draft_token_workspace_manager, engine_config,\n                                                trace_recorder)};\n    } else if (engine_config->spec_draft_length > 0) {\n      // The \"small draft\" 
mode speculative decoding.\n      // If \"engine_config->spec_draft_length\" > 0, it means the draft length is\n      // configured to be a fixed value.\n      actions = {\n          EngineAction::NewRequestPrefill(models,            //\n                                          logit_processor,   //\n                                          sampler,           //\n                                          model_workspaces,  //\n                                          engine_config,     //\n                                          model_configs,     //\n                                          trace_recorder),\n          EngineAction::BatchDraft(models, logit_processor, sampler, model_workspaces,\n                                   draft_token_workspace_manager, engine_config, trace_recorder),\n          EngineAction::BatchVerify(models, logit_processor, sampler, model_workspaces,\n                                    draft_token_workspace_manager, engine_config, trace_recorder)};\n    } else {\n      // The \"small draft\" mode speculative decoding.\n      // \"engine_config->spec_draft_length\" being 0 means we want to enable\n      // automatic speculative decoding, which decides the spec decoding draft length\n      // automatically.\n      actions = {EngineAction::NewRequestPrefill(models,            //\n                                                 logit_processor,   //\n                                                 sampler,           //\n                                                 model_workspaces,  //\n                                                 engine_config,     //\n                                                 model_configs,     //\n                                                 trace_recorder),\n                 EngineAction::AutoSpecDecode(\n                     /*spec_decode_actions=*/{EngineAction::BatchDraft(\n                                                  models, logit_processor, sampler,\n                                 
                 model_workspaces, draft_token_workspace_manager,\n                                                  engine_config, trace_recorder),\n                                              EngineAction::BatchVerify(\n                                                  models, logit_processor, sampler,\n                                                  model_workspaces, draft_token_workspace_manager,\n                                                  engine_config, trace_recorder)},\n                     /*batch_decode_actions=*/\n                     {EngineAction::BatchDecode(models, tokenizer, logit_processor, sampler,\n                                                engine_config, trace_recorder)},\n                     engine_config)};\n    }\n  } else if (model_metadata.disaggregation) {\n    actions = {EngineAction::NewRequestPrefill(models,            //\n                                               logit_processor,   //\n                                               sampler,           //\n                                               model_workspaces,  //\n                                               engine_config,     //\n                                               model_configs,     //\n                                               trace_recorder),\n               EngineAction::BatchDecode(models, tokenizer, logit_processor, sampler, engine_config,\n                                         trace_recorder)};\n  } else {\n    // The normal mode.\n    actions = {EngineAction::NewRequestPrefill(models,            //\n                                               logit_processor,   //\n                                               sampler,           //\n                                               model_workspaces,  //\n                                               engine_config,     //\n                                               model_configs,     //\n                                               trace_recorder),\n               
EngineAction::BatchJumpForward(models, tokenizer, trace_recorder),\n               EngineAction::BatchDecode(models, tokenizer, logit_processor, sampler, engine_config,\n                                         trace_recorder)};\n  }\n\n  if (model_metadata.disaggregation) {\n    // Insert the disaggregation actions.\n    Array<EngineAction> disaggregation_actions = {\n        EngineAction::DisaggPrepareReceive(models, engine_config, model_configs, trace_recorder,\n                                           request_stream_callback),\n        EngineAction::DisaggRemoteSend(models, model_workspaces, engine_config, model_configs,\n                                       trace_recorder, request_stream_callback, device)};\n    actions.insert(actions.begin(), disaggregation_actions.begin(), disaggregation_actions.end());\n  }\n  return actions;\n}\n\nvoid RemoveRequestFromModel(EngineState estate, int64_t req_internal_id,\n                            const Array<Model>& models) {\n  // Remove the request from all models (usually the KV cache).\n  for (Model model : models) {\n    model->RemoveSequence(req_internal_id);\n  }\n}\n\n/*!\n * \\brief Remove the given request state entry.\n * \\param estate The engine state to update after removal.\n * \\param models The models to remove the given request from.\n * \\param rsentry The request state entry to remove.\n */\nvoid RemoveRequestStateEntry(EngineState estate, const Array<Model>& models,\n                             RequestStateEntry rsentry,\n                             Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager) {\n  if (draft_token_workspace_manager.defined()) {\n    std::vector<int> draft_token_slots;\n    for (const RequestModelState& mstate : rsentry->mstates) {\n      mstate->RemoveAllDraftTokens(&draft_token_slots);\n      draft_token_workspace_manager.value()->FreeSlots(draft_token_slots);\n    }\n  }\n  if (estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {\n    
// If the sequence is stored in prefix cache, call prefix cache to remove.\n    if (!(rsentry->request->generation_cfg->debug_config.pinned_system_prompt)) {\n      // If the request is not pinned, recycle the request.\n      estate->prefix_cache->RecycleSequence(rsentry->mstates[0]->internal_id, /*lazy=*/true);\n    }\n    // If the request is pinned, do nothing over the prefix cache and KVCache.\n  } else {\n    // If the sequence is not stored in prefix cache, remove it directly.\n    RemoveRequestFromModel(estate, rsentry->mstates[0]->internal_id, models);\n    estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id);\n  }\n}\n\nvoid ProcessFinishedRequestStateEntries(\n    const std::vector<RequestStateEntry>& finished_rsentries, EngineState estate,\n    const Array<Model>& models, int max_single_sequence_length,\n    Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,\n    Array<RequestStreamOutput>* callback_delta_outputs) {\n  NVTXScopedRange nvtx_scope(\"Process finished requests\");\n  // - Remove the finished request state entries.\n  for (const RequestStateEntry& rsentry : finished_rsentries) {\n    // The finished entry must be a leaf.\n    TVM_FFI_ICHECK(rsentry->child_indices.empty());\n    // Mark the status of this entry as finished.\n    rsentry->status = RequestStateStatus::kFinished;\n    // Remove the request state entry from all the models.\n    RemoveRequestStateEntry(estate, models, rsentry, draft_token_workspace_manager);\n\n    RequestState rstate = estate->GetRequestState(rsentry->request);\n    int parent_idx = rsentry->parent_idx;\n    while (parent_idx != -1) {\n      bool all_children_finished = true;\n      for (int child_idx : rstate->entries[parent_idx]->child_indices) {\n        if (rstate->entries[child_idx]->status != RequestStateStatus::kFinished) {\n          all_children_finished = false;\n          break;\n        }\n      }\n      if (!all_children_finished) {\n        break;\n      }\n\n      // All 
the children of the parent request state entry have finished.\n      // So we mark the parent entry as finished.\n      rstate->entries[parent_idx]->status = RequestStateStatus::kFinished;\n      // Remove the request state entry from all the models.\n\n      RemoveRequestStateEntry(estate, models, rstate->entries[parent_idx],\n                              draft_token_workspace_manager);\n      // Climb up to the parent.\n      parent_idx = rstate->entries[parent_idx]->parent_idx;\n    }\n\n    if (parent_idx == -1) {\n      // Remove from running queue and engine state.\n      auto it =\n          std::find(estate->running_queue.begin(), estate->running_queue.end(), rsentry->request);\n      TVM_FFI_ICHECK(it != estate->running_queue.end());\n      estate->running_queue.erase(it);\n      estate->request_states.erase(rsentry->request->id);\n\n      // Update engine metrics.\n      const RequestStateEntry& root_rsentry = rstate->entries[0];\n      auto trequest_finish = std::chrono::high_resolution_clock::now();\n\n      rstate->metrics.finish_time_point = trequest_finish;\n      estate->metrics.RequestFinishUpdate(rstate->metrics);\n\n      // always stream back usage in backend\n      callback_delta_outputs->push_back(RequestStreamOutput::Usage(\n          root_rsentry->request->id, rstate->metrics.AsUsageJSONStr(true)));\n    }\n    estate->running_rsentries_changed = true;\n  }\n}\n\nvoid ActionStepPostProcess(Array<Request> requests, EngineState estate, const Array<Model>& models,\n                           const Tokenizer& tokenizer,\n                           FRequestStreamCallback request_stream_callback,\n                           int64_t max_single_sequence_length,\n                           Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,\n                           Optional<EventTraceRecorder> trace_recorder) {\n  NVTXScopedRange nvtx_scope(\"EngineAction postproc\");\n  int num_requests = requests.size();\n  
estate->postproc_workspace.finished_rsentries.clear();\n  estate->postproc_workspace.callback_delta_outputs.clear();\n  estate->postproc_workspace.finished_rsentries.reserve(num_requests);\n  estate->postproc_workspace.callback_delta_outputs.reserve(num_requests * 2);\n\n  // - Collect new generated tokens and finish reasons for requests.\n  for (int r = 0; r < num_requests; ++r) {\n    Request request = requests[r];\n    int n = request->generation_cfg->n;\n    RequestState rstate = estate->GetRequestState(requests[r]);\n\n    bool invoke_callback = false;\n    RequestStreamOutput stream_output = rstate->postproc_states.GetStreamOutput();\n    for (int i = 0; i < n; ++i) {\n      const RequestStateEntry& rsentry = n == 1 ? rstate->entries[0] : rstate->entries[i + 1];\n      rsentry->GetDeltaRequestReturn(tokenizer, max_single_sequence_length, &stream_output, i);\n      if (stream_output->group_finish_reason[i].has_value()) {\n        invoke_callback = true;\n        estate->postproc_workspace.finished_rsentries.push_back(rsentry);\n      }\n      if (!stream_output->group_delta_token_ids[i].empty() ||\n          !stream_output->group_extra_prefix_string[i].empty()) {\n        invoke_callback = true;\n      }\n    }\n\n    if (invoke_callback) {\n      stream_output->unpacked = false;\n      estate->postproc_workspace.callback_delta_outputs.push_back(std::move(stream_output));\n    }\n\n    // Update prefix cache and metrics.\n    for (const RequestStateEntry& rsentry : rstate->entries) {\n      std::vector<int32_t>& token_ids = rsentry->token_ids_for_prefix_cache_update;\n      token_ids.clear();\n      if (!rsentry->mstates[0]->prefilled_inputs.empty()) {\n        // Notify the prefix cache of the newly prefilled data.\n        for (const Data& data : rsentry->mstates[0]->prefilled_inputs) {\n          const TokenDataNode* token_data = data.as<TokenDataNode>();\n          if (token_data == nullptr) continue;\n          token_ids.insert(token_ids.end(), 
token_data->token_ids->data,\n                           token_data->token_ids->data + token_data->token_ids.size());\n          // note that we are counting prefill tokens across all branches\n          rstate->metrics.prefill_tokens += data->GetLength();\n        }\n        rsentry->mstates[0]->prefilled_inputs.clear();\n      }\n      int64_t num_committed_tokens = rsentry->mstates[0]->committed_tokens.size();\n      if (rsentry->mstates[0]->cached_committed_tokens < num_committed_tokens - 1) {\n        // Notify the prefix cache of the newly decoded data, except the last token as it is not\n        // in KVCache yet.\n        for (int64_t& i = rsentry->mstates[0]->cached_committed_tokens;\n             i < num_committed_tokens - 1; ++i) {\n          token_ids.push_back(rsentry->mstates[0]->committed_tokens[i].sampled_token_id.first);\n        }\n      }\n      if (!token_ids.empty()) {\n        estate->prefix_cache->ExtendSequence(rsentry->mstates[0]->internal_id, token_ids);\n      }\n    }\n\n    // - For all disaggregation requests with \"remote_send\",\n    // if it does not appear in the waiting queue, it means the prefill has been finished.\n    // In this case, we mark the request as finished.\n    if (request->generation_cfg->debug_config.disagg_config.kind ==\n        DisaggRequestKind::kRemoteSend) {\n      auto it = std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request);\n      if (it == estate->waiting_queue.end()) {\n        TVM_FFI_ICHECK_EQ(rstate->entries.size(), 1);\n        estate->postproc_workspace.finished_rsentries.push_back(rstate->entries[0]);\n      }\n    }\n  }\n\n  ProcessFinishedRequestStateEntries(estate->postproc_workspace.finished_rsentries, estate, models,\n                                     max_single_sequence_length, draft_token_workspace_manager,\n                                     &estate->postproc_workspace.callback_delta_outputs);\n\n  if 
(!estate->postproc_workspace.callback_delta_outputs.empty()) {\n    NVTXScopedRange nvtx_scope(\"Call request stream callback\");\n    // - Invoke the stream callback function once for all collected requests.\n    request_stream_callback(estate->postproc_workspace.callback_delta_outputs);\n  }\n}\n\nRequestStateEntry PreemptLastRunningRequestStateEntry(\n    EngineState estate, const Array<Model>& models,\n    Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,\n    Optional<EventTraceRecorder> trace_recorder) {\n  TVM_FFI_ICHECK(!estate->running_queue.empty());\n  Request request = estate->running_queue.back();\n\n  // Find the last alive request state entry, which is what we want to preempt.\n  RequestState rstate = estate->GetRequestState(request);\n  int preempt_rstate_idx = -1;\n  for (int i = static_cast<int>(rstate->entries.size()) - 1; i >= 0; --i) {\n    if (rstate->entries[i]->status == RequestStateStatus::kAlive) {\n      preempt_rstate_idx = i;\n      break;\n    }\n  }\n  TVM_FFI_ICHECK_NE(preempt_rstate_idx, -1);\n  RequestStateEntry rsentry = rstate->entries[preempt_rstate_idx];\n  if (estate->disaggregation) {\n    AbortRequestImpl(estate, models, request->id, \"preempt\");\n    return rsentry;\n  }\n  // When the request state entry still has pending inputs,\n  // it means the request is still in the waiting queue.\n  bool partially_alive = !rsentry->mstates[0]->inputs.empty();\n\n  // Remove from models.\n  // - Clear model speculation draft.\n  // - Update `inputs` for future prefill.\n  RECORD_EVENT(trace_recorder, rsentry->request->id, \"preempt\");\n  rsentry->status = RequestStateStatus::kPending;\n  std::vector<int> draft_token_slots;\n  for (RequestModelState mstate : rsentry->mstates) {\n    if (draft_token_workspace_manager.defined()) {\n      mstate->RemoveAllDraftTokens(&draft_token_slots);\n      draft_token_workspace_manager.value()->FreeSlots(draft_token_slots);\n    }\n\n    // If the committed 
tokens of the current model lag behind the\n    // committed tokens of the main model (models[0]), we commit those\n    // new tokens to this model.\n    for (size_t i = mstate->committed_tokens.size();\n         i < rsentry->mstates[0]->committed_tokens.size(); ++i) {\n      mstate->CommitToken(rsentry->mstates[0]->committed_tokens[i]);\n    }\n\n    std::vector<int32_t> committed_token_ids;\n    committed_token_ids.reserve(mstate->committed_tokens.size());\n    for (const SampleResult& committed_token : mstate->committed_tokens) {\n      committed_token_ids.push_back(committed_token.GetTokenId());\n    }\n    mstate->num_prefilled_tokens = 0;\n\n    Array<Data> inputs;\n    if (rsentry->parent_idx == -1) {\n      inputs = request->inputs;\n      if (const auto* token_input = inputs.back().as<TokenDataNode>()) {\n        // Merge the TokenData so that TokenEmbed only needs to be called once.\n        std::vector<int> token_ids{token_input->token_ids->data,\n                                   token_input->token_ids->data + token_input->token_ids.size()};\n        token_ids.insert(token_ids.end(), committed_token_ids.begin(), committed_token_ids.end());\n        inputs.Set(static_cast<int64_t>(inputs.size()) - 1, TokenData(token_ids));\n      } else if (!committed_token_ids.empty()) {\n        inputs.push_back(TokenData(committed_token_ids));\n      }\n    } else if (!committed_token_ids.empty()) {\n      inputs.push_back(TokenData(committed_token_ids));\n    }\n    mstate->inputs = std::move(inputs);\n    mstate->prefilled_inputs.clear();\n    mstate->cached_committed_tokens = 0;\n    mstate->num_tokens_for_next_decode = 0;\n  }\n  if (estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {\n    estate->prefix_cache->RecycleSequence(rsentry->mstates[0]->internal_id, /*lazy=*/false);\n  } else {\n    RemoveRequestFromModel(estate, rsentry->mstates[0]->internal_id, models);\n  }\n  // Since the sequence has been removed from the model, assign a new sequence 
ID.\n  int64_t new_seq_id = estate->id_manager.GetNewId();\n  for (RequestModelState mstate : rsentry->mstates) {\n    mstate->internal_id = new_seq_id;\n  }\n\n  if (preempt_rstate_idx == 0) {\n    // Remove from running queue.\n    estate->running_queue.erase(estate->running_queue.end() - 1);\n  }\n  if (!partially_alive && preempt_rstate_idx == static_cast<int>(rstate->entries.size()) - 1) {\n    // Add to the front of waiting queue.\n    estate->waiting_queue.insert(estate->waiting_queue.begin(), request);\n  }\n  estate->running_rsentries_changed = true;\n  return rsentry;\n}\n\nstd::pair<Tensor, std::vector<SampleResult>> ApplyLogitProcessorAndSample(\n    const LogitProcessor& logit_processor, const Sampler& sampler, const Tensor& logits,\n    const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,\n    const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,\n    const std::vector<int>& sample_indices, const Array<GenerationConfig>& child_generation_cfg,\n    const Array<String>& child_request_ids, const std::vector<int>& child_sample_indices) {\n  // - Update logits.\n  logit_processor->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);\n\n  // - Compute probability distributions.\n  Tensor probs_on_device =\n      logit_processor->ComputeProbsFromLogits(logits, generation_cfg, request_ids);\n\n  // - Sample tokens.\n  Tensor renormalized_probs = sampler->BatchRenormalizeProbsByTopP(probs_on_device, sample_indices,\n                                                                   request_ids, generation_cfg);\n  std::vector<SampleResult> sample_results = sampler->BatchSampleTokensWithProbAfterTopP(\n      renormalized_probs, child_sample_indices, child_request_ids, child_generation_cfg, rngs);\n  return {std::move(probs_on_device), std::move(sample_results)};\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/action_commons.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/action_commons.h\n * \\brief Common functions that may be used in multiple EngineActions.\n */\n#ifndef MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_\n#define MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_\n\n#include <tvm/ffi/container/array.h>\n\n#include \"../../tokenizers/tokenizers.h\"\n#include \"../draft_token_workspace_manager.h\"\n#include \"../engine.h\"\n#include \"../engine_state.h\"\n#include \"../event_trace_recorder.h\"\n#include \"../model.h\"\n#include \"action.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/*! \\brief Create the engine actions based on engine config. */\nArray<EngineAction> CreateEngineActions(Array<Model> models, EngineConfig engine_config,\n                                        std::vector<tvm::ffi::json::Object> model_configs,\n                                        std::vector<ModelWorkspace> model_workspaces,\n                                        LogitProcessor logit_processor, Sampler sampler,\n                                        DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                        Tokenizer tokenizer,\n                                        Optional<EventTraceRecorder> trace_recorder,\n                                        FRequestStreamCallback request_stream_callback,\n                                        Device device);\n\n/*!\n * \\brief Remove the given request from models.\n * \\param estate The engine state to update after removal.\n * \\param req_internal_id The internal id of the request to remove.\n * \\param models The models to remove the given request from.\n */\nvoid RemoveRequestFromModel(EngineState estate, int64_t req_internal_id,\n                            const Array<Model>& models);\n\n/*!\n * \\brief The request post-processing after an engine action step.\n * It includes\n * - invoke the request 
stream callback to return newly generated tokens,\n * - update the engine state for finished requests.\n * \\note This function may remove requests from the `running_queue`.\n * \\param requests The requests to process.\n * \\param estate The engine state.\n * \\param models The models to remove the finished requests from.\n * \\param tokenizer The tokenizer for logprob processing.\n * \\param request_stream_callback The request stream callback function.\n * \\param max_single_sequence_length The max single sequence length to help decide\n * if a request is finished.\n * \\param draft_token_workspace_manager The draft token workspace manager.\n * \\param trace_recorder The event trace recorder for requests.\n */\nvoid ActionStepPostProcess(Array<Request> requests, EngineState estate, const Array<Model>& models,\n                           const Tokenizer& tokenizer,\n                           FRequestStreamCallback request_stream_callback,\n                           int64_t max_single_sequence_length,\n                           Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,\n                           Optional<EventTraceRecorder> trace_recorder);\n\n/*!\n * \\brief Preempt the last running request state entry from `running_queue`.\n * If all entries of the selected request have been preempted,\n * remove it from the running queue.\n * If it is not in the waiting request queue, add it to the waiting queue.\n * \\param estate The engine state to update due to preemption.\n * \\param models The models to remove preempted requests from.\n * \\param draft_token_workspace_manager The draft token workspace manager for requests. 
Must be\n * provided if speculative decoding is enabled.\n * \\param trace_recorder The event trace recorder for requests.\n * \\return The preempted request state entry.\n */\nRequestStateEntry PreemptLastRunningRequestStateEntry(\n    EngineState estate, const Array<Model>& models,\n    Optional<DraftTokenWorkspaceManager> draft_token_workspace_manager,\n    Optional<EventTraceRecorder> trace_recorder);\n\n/*!\n * \\brief Apply the logit processor to the logits and sample one token for each request.\n *\n * Both the parent request configurations and the child request configurations need to be provided.\n * The parent request configurations are used to process the logits and normalize the probabilities.\n * The child request configurations are used to sample the tokens.\n *\n * When the request doesn't have children, the parent and child configurations are the same.\n *\n * \\param logit_processor The logit processor to apply.\n * \\param sampler The sampler to sample tokens.\n * \\param logits The logits to process.\n * \\param generation_cfg The generation configurations of the requests.\n * \\param request_ids The request ids.\n * \\param mstates The model states of the requests.\n * \\param rngs The random generators of the requests.\n * \\param sample_indices The indices of the requests to sample.\n * \\param child_generation_cfg The generation configurations of the child requests.\n * \\param child_request_ids The request ids of the child requests.\n * \\param child_sample_indices The indices of the child requests to sample.\n * \\return The on-device probability distributions and the sampled results.\n */\nstd::pair<Tensor, std::vector<SampleResult>> ApplyLogitProcessorAndSample(\n    const LogitProcessor& logit_processor, const Sampler& sampler, const Tensor& logits,\n    const Array<GenerationConfig>& generation_cfg, const Array<String>& request_ids,\n    const Array<RequestModelState>& mstates, const std::vector<RandomGenerator*>& rngs,\n    
const std::vector<int>& sample_indices, 
const Array<GenerationConfig>& child_generation_cfg,\n    const Array<String>& child_request_ids, const std::vector<int>& child_sample_indices);\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_ENGINE_ACTIONS_ACTION_COMMONS_H_\n"
  },
  {
    "path": "cpp/serve/engine_actions/auto_spec_decode.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/auto_spec_decode.cc\n */\n\n#include <tvm/runtime/nvtx.h>\n\n#include <numeric>\n\n#include \"../config.h\"\n#include \"action.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that first makes a decision on whether to run speculative\n * decoding or normal mode batch decode, and then runs the selected actions.\n */\nclass AutoSpecDecodeActionObj : public EngineActionObj {\n public:\n  explicit AutoSpecDecodeActionObj(Array<EngineAction> spec_decode_actions,\n                                   Array<EngineAction> batch_decode_actions,\n                                   EngineConfig engine_config)\n      : spec_decode_actions_(std::move(spec_decode_actions)),\n        batch_decode_actions_(std::move(batch_decode_actions)),\n        engine_config_(std::move(engine_config)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    int num_running_rsentries = estate->GetRunningRequestStateEntries().size();\n    if (num_running_rsentries == 0) {\n      return {};\n    }\n\n    // Calculate the draft length to use for the next round decode.\n    estate->spec_draft_length = CalculateDraftLength(estate, num_running_rsentries);\n    TVM_FFI_ICHECK_GE(estate->spec_draft_length, 0);\n    Array<Request> processed_requests;\n    // Use speculative decoding when the computed draft length is positive.\n    // Otherwise use normal mode batch decode.\n    Array<EngineAction> actions =\n        estate->spec_draft_length > 0 ? 
spec_decode_actions_ : batch_decode_actions_;\n    for (EngineAction action : actions) {\n      processed_requests = action->Step(estate);\n    }\n\n    // Reset the draft length.\n    estate->spec_draft_length = 0;\n    return processed_requests;\n  }\n\n private:\n  int CalculateDraftLength(EngineState estate, int num_running_rsentries) {\n    // Right now we use a fixed table to select the draft length (based only on\n    // the batch size). A more powerful draft length selection strategy will follow.\n    int draft_length = 0;\n    if (num_running_rsentries < 10) {\n      draft_length = 4;\n    } else if (num_running_rsentries < 20) {\n      draft_length = 3;\n    } else if (num_running_rsentries < 30) {\n      draft_length = 2;\n    } else {\n      draft_length = 0;\n    }\n\n    int effective_batch_size = num_running_rsentries * (draft_length + 1);\n    return effective_batch_size > engine_config_->max_num_sequence ? 0 : draft_length;\n  }\n\n  /*! \\brief The speculative decode actions. */\n  Array<EngineAction> spec_decode_actions_;\n  /*! \\brief The normal mode decode actions. */\n  Array<EngineAction> batch_decode_actions_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n};\n\nEngineAction EngineAction::AutoSpecDecode(std::vector<EngineAction> spec_decode_actions_,\n                                          std::vector<EngineAction> batch_decode_actions_,\n                                          EngineConfig engine_config) {\n  return EngineAction(tvm::ffi::make_object<AutoSpecDecodeActionObj>(\n      Array<EngineAction>(spec_decode_actions_), Array<EngineAction>(batch_decode_actions_),\n      std::move(engine_config)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/batch_decode.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/batch_decode.cc\n */\n\n#include <tvm/runtime/nvtx.h>\n\n#include <numeric>\n\n#include \"../../support/random.h\"\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs one-step decode for requests in the\n * `running_queue` of engine state. Preempt low-priority requests\n * accordingly when it is impossible to decode all the running requests.\n * \\note The BatchDecode action **does not** take effect for speculative\n * decoding scenarios where there are multiple models. For speculative\n * decoding in the future, we will use other specific actions.\n */\nclass BatchDecodeActionObj : public EngineActionObj {\n public:\n  explicit BatchDecodeActionObj(Array<Model> models, Tokenizer tokenizer,\n                                LogitProcessor logit_processor, Sampler sampler,\n                                EngineConfig engine_config,\n                                Optional<EventTraceRecorder> trace_recorder)\n      : models_(std::move(models)),\n        tokenizer_(std::move(tokenizer)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        engine_config_(std::move(engine_config)),\n        trace_recorder_(std::move(trace_recorder)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Do not run decode when there is no running request.\n    if (estate->running_queue.empty()) {\n      return {};\n    }\n\n    // Preempt request state entries when decode cannot apply.\n    std::vector<RequestStateEntry> running_rsentries;\n    {\n      NVTXScopedRange nvtx_scope(\"BatchDecode getting requests\");\n      running_rsentries = estate->GetRunningRequestStateEntries();\n      while (!CanDecode(running_rsentries.size())) {\n        if 
(estate->prefix_cache->TryFreeMemory()) continue;\n        RequestStateEntry preempted =\n            PreemptLastRunningRequestStateEntry(estate, models_, std::nullopt, trace_recorder_);\n        if (preempted.same_as(running_rsentries.back())) {\n          running_rsentries.pop_back();\n        }\n      }\n      while (running_rsentries.size() >\n             std::min(static_cast<int64_t>(engine_config_->max_num_sequence),\n                      engine_config_->prefill_chunk_size)) {\n        running_rsentries.pop_back();\n      }\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n\n    // NOTE: Right now we only support decoding all the running request states at once.\n    int num_rsentries = running_rsentries.size();\n    TVM_FFI_ICHECK_GT(num_rsentries, 0)\n        << \"There should be at least one request state entry that can run decode. \"\n           \"Possible failure reason: none of the prefill phase of the running requests is finished\";\n    TVM_FFI_ICHECK_LE(num_rsentries, engine_config_->max_num_sequence)\n        << \"The number of running requests exceeds the max number of sequence in EngineConfig. 
\"\n           \"Possible failure reason: the prefill action allows new sequence in regardless of the \"\n           \"max num sequence.\";\n    // Collect\n    // - the last committed token,\n    // - the request id,\n    // - the generation config,\n    // - the random number generator,\n    // of each request state entry.\n    std::vector<int> input_tokens;\n    std::vector<int> lengths;\n    Array<String> request_ids;\n    std::vector<int64_t> request_internal_ids;\n    Array<RequestModelState> mstates;\n    Array<GenerationConfig> generation_cfg;\n    std::vector<RandomGenerator*> rngs;\n\n    input_tokens.reserve(num_rsentries);\n    request_ids.reserve(num_rsentries);\n    request_internal_ids.reserve(num_rsentries);\n    mstates.reserve(num_rsentries);\n    generation_cfg.reserve(num_rsentries);\n    rngs.reserve(num_rsentries);\n\n    {\n      NVTXScopedRange nvtx_scope(\"BatchDecode setting batch info\");\n      for (const RequestStateEntry& rsentry : running_rsentries) {\n        auto mstate = rsentry->mstates[0];\n        TVM_FFI_ICHECK(mstate->num_tokens_for_next_decode > 0 &&\n                       mstate->num_tokens_for_next_decode <=\n                           static_cast<int>(mstate->committed_tokens.size()));\n\n        for (auto begin = mstate->committed_tokens.end() - mstate->num_tokens_for_next_decode;\n             begin != mstate->committed_tokens.end(); ++begin) {\n          input_tokens.push_back(begin->GetTokenId());\n        }\n\n        lengths.push_back(mstate->num_tokens_for_next_decode);\n        mstate->num_tokens_for_next_decode = 0;\n\n        request_ids.push_back(rsentry->request->id);\n        request_internal_ids.push_back(mstate->internal_id);\n        mstates.push_back(mstate);\n        generation_cfg.push_back(rsentry->request->generation_cfg);\n        rngs.push_back(&rsentry->rng);\n      }\n    }\n\n    // - Compute embeddings.\n    RECORD_EVENT(trace_recorder_, request_ids, \"start embedding\");\n    ObjectRef 
embeddings =\n        models_[0]->TokenEmbed({IntTuple(input_tokens.begin(), input_tokens.end())});\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish embedding\");\n\n    // - Invoke model decode.\n    // If every request only needs to process one token, the batch decode kernel is called.\n    // Otherwise, the batch prefill kernel is called.\n    bool is_every_request_single_token =\n        std::all_of(lengths.begin(), lengths.end(), [](int len) { return len == 1; });\n    RECORD_EVENT(trace_recorder_, request_ids, \"start decode\");\n    Tensor logits;\n    if (is_every_request_single_token) {\n      logits = models_[0]->BatchDecode(embeddings, request_internal_ids);\n      TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n      TVM_FFI_ICHECK_EQ(logits->shape[0], num_rsentries);\n      TVM_FFI_ICHECK_EQ(logits->shape[1], 1);\n    } else {\n      logits = models_[0]->BatchPrefill(embeddings, request_internal_ids, lengths);\n      TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n      TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n      TVM_FFI_ICHECK_EQ(logits->shape[1], num_rsentries);\n    }\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish decode\");\n\n    // - Update logits.\n    logits = logits.CreateView({num_rsentries, logits->shape[2]}, logits->dtype);\n    logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids);\n\n    // - Compute probability distributions.\n    Tensor probs_on_device =\n        logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids);\n\n    // - Commit the prefix cache changes from the previous round of actions.\n    // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n    estate->prefix_cache->CommitSequenceExtention();\n\n    // - Sample tokens.\n    // Fill range [0, num_rsentries) into `sample_indices`.\n    std::vector<int> sample_indices(num_rsentries);\n    std::iota(sample_indices.begin(), sample_indices.end(), 0);\n    Tensor renormalized_probs = 
sampler_->BatchRenormalizeProbsByTopP(\n        probs_on_device, sample_indices, request_ids, generation_cfg);\n    std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(\n        renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);\n    TVM_FFI_ICHECK_EQ(sample_results.size(), num_rsentries);\n\n    // - Update the committed tokens of states.\n    for (int i = 0; i < num_rsentries; ++i) {\n      auto mstate = mstates[i];\n\n      if (!mstate->require_retokenization_in_next_decode) {\n        mstates[i]->CommitToken(sample_results[i]);\n        // live update the output metrics\n        running_rsentries[i]->rstate->metrics.completion_tokens += 1;\n      } else {\n        // Retokenize and commit tokens.\n        CommitTokenMayRetokenize(running_rsentries[i], mstate, sample_results[i]);\n        mstate->require_retokenization_in_next_decode = false;\n      }\n\n      running_rsentries[i]->rstate->metrics.decode_tokens += lengths[i];\n    }\n\n    double elapsed_time;\n    {\n      NVTXScopedRange nvtx_scope(\"BatchDecode get time\");\n      auto tend = std::chrono::high_resolution_clock::now();\n      elapsed_time = static_cast<double>((tend - tstart).count()) / 1e9;\n    }\n    estate->metrics.engine_decode_time_sum += elapsed_time;\n    estate->metrics.UpdateDecodeTimeByBatchSize(num_rsentries, elapsed_time);\n\n    return estate->running_queue;\n  }\n\n private:\n  /*! \\brief Check if the input request state entries can be decoded under conditions. 
*/\n  bool CanDecode(int num_rsentries) {\n    int num_available_pages = models_[0]->GetNumAvailablePages();\n    return num_rsentries <= num_available_pages;\n  }\n\n  /*!\n   * \\brief Retokenize the past tokens with a new token.\n   * \\param mstate The model state.\n   * \\param token_id The new token id.\n   * \\param max_rollback_tokens The maximum number of tokens to rollback.\n   * \\return The number of tokens to rollback and the new tokens.\n   */\n  std::pair<int, std::vector<int32_t>> RetokenizeWithNewToken(RequestModelState mstate,\n                                                              int32_t token_id,\n                                                              int max_rollback_tokens) {\n    // Step 1. Get past tokens\n    // past_tokens = mstate[-max_rollback_tokens:]\n    // past_string = detokenize(past_tokens)\n    const auto& token_table = tokenizer_->PostProcessedTokenTable();\n    std::vector<int32_t> past_tokens;\n    std::string past_string;\n    auto past_begin_it = mstate->committed_tokens.size() >= max_rollback_tokens\n                             ? mstate->committed_tokens.end() - max_rollback_tokens\n                             : mstate->committed_tokens.begin();\n    for (auto it = past_begin_it; it != mstate->committed_tokens.end(); ++it) {\n      past_tokens.push_back(it->GetTokenId());\n      past_string += token_table[it->GetTokenId()];\n    }\n\n    // Step 2. 
Retokenize\n    // Compare tokenize(past_string + new_string) and past_tokens\n    auto new_tokens = tokenizer_->EncodeNoPrependSpace(past_string + token_table[token_id]);\n\n    int first_differ_idx = past_tokens.size();\n    for (int i = 0; i < static_cast<int>(past_tokens.size()); ++i) {\n      if (i == static_cast<int>(new_tokens.size()) || past_tokens[i] != new_tokens[i]) {\n        first_differ_idx = i;\n        break;\n      }\n    }\n\n    return {past_tokens.size() - first_differ_idx,\n            std::vector<int32_t>(new_tokens.begin() + first_differ_idx, new_tokens.end())};\n  }\n\n  /*!\n   * \\brief Commit the token, possibly retokenizing the past tokens.\n   * \\param rsentry The request state entry.\n   * \\param mstate The model state.\n   * \\param sample_result The sampled token.\n   */\n  void CommitTokenMayRetokenize(RequestStateEntry rsentry, RequestModelState mstate,\n                                const SampleResult& sample_result) {\n    auto generation_cfg = rsentry->request->generation_cfg;\n    // 1. If a stop token (e.g., EOS) is generated, commit it directly without retokenization.\n    if (!generation_cfg->debug_config.ignore_eos &&\n        std::any_of(generation_cfg->stop_token_ids.begin(), generation_cfg->stop_token_ids.end(),\n                    [&](int32_t token) { return token == sample_result.GetTokenId(); })) {\n      mstate->CommitToken(sample_result);\n      rsentry->rstate->metrics.completion_tokens += 1;\n      return;\n    }\n\n    // 2. Check retokenization\n    const auto& committed_tokens = mstate->committed_tokens;\n    auto [rollback_cnt, new_tokens] =\n        RetokenizeWithNewToken(mstate, sample_result.GetTokenId(), MAX_ROLLBACK_TOKENS_);\n\n    // 3. 
Handle output when retokenization happens\n    if (rollback_cnt >\n        static_cast<int>(committed_tokens.size()) - rsentry->next_callback_token_pos) {\n      const auto& token_table = tokenizer_->PostProcessedTokenTable();\n      for (auto i = rsentry->next_callback_token_pos; i < committed_tokens.size(); ++i) {\n        auto token_id = committed_tokens[i].GetTokenId();\n        rsentry->extra_prefix_string += token_table[token_id];\n      }\n      rsentry->extra_prefix_string += token_table[sample_result.GetTokenId()];\n      rsentry->next_callback_token_pos = static_cast<int>(committed_tokens.size()) - rollback_cnt +\n                                         static_cast<int>(new_tokens.size());\n    }\n\n    if (rollback_cnt > 0) {\n      mstate->RollbackTokens(rollback_cnt);\n      models_[0]->PopNFromKVCache(mstate->internal_id, rollback_cnt);\n    }\n\n    for (auto token_id : new_tokens) {\n      mstate->CommitToken({{token_id, 1.0}, {}});\n    }\n\n    rsentry->rstate->metrics.completion_tokens +=\n        static_cast<int>(new_tokens.size()) - rollback_cnt;\n  }\n\n  /*!\n   * \\brief The models to run decode in. When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   */\n  Array<Model> models_;\n  /*! \\brief The tokenizer of the engine. */\n  Tokenizer tokenizer_;\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! \\brief The maximum number of tokens that may be retokenized and rolled back. 
*/\n  const int MAX_ROLLBACK_TOKENS_ = 10;\n};\n\nEngineAction EngineAction::BatchDecode(Array<Model> models, Tokenizer tokenizer,\n                                       LogitProcessor logit_processor, Sampler sampler,\n                                       EngineConfig engine_config,\n                                       Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<BatchDecodeActionObj>(\n      std::move(models), std::move(tokenizer), std::move(logit_processor), std::move(sampler),\n      std::move(engine_config), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/batch_draft.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/batch_draft.cc\n */\n\n#include <numeric>\n\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs draft proposal for requests in the\n * `running_queue` of engine state. Preempt low-priority requests\n * accordingly when it is impossible to decode all the running requests.\n */\nclass BatchDraftActionObj : public EngineActionObj {\n public:\n  explicit BatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor, Sampler sampler,\n                               std::vector<ModelWorkspace> model_workspaces,\n                               DraftTokenWorkspaceManager draft_token_workspace_manager,\n                               EngineConfig engine_config,\n                               Optional<EventTraceRecorder> trace_recorder)\n      : models_(std::move(models)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        model_workspaces_(std::move(model_workspaces)),\n        draft_token_workspace_manager_(std::move(draft_token_workspace_manager)),\n        engine_config_(std::move(engine_config)),\n        trace_recorder_(std::move(trace_recorder)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests.\n    if (models_.size() != 2 || estate->running_queue.empty()) {\n      return {};\n    }\n\n    // Preempt request state entries when decode cannot apply.\n    std::vector<RequestStateEntry> running_rsentries = estate->GetRunningRequestStateEntries();\n    while (!CanDecode(running_rsentries.size())) {\n      if (estate->prefix_cache->TryFreeMemory()) continue;\n      RequestStateEntry preempted = PreemptLastRunningRequestStateEntry(\n         
 estate, models_, draft_token_workspace_manager_, trace_recorder_);\n      if (preempted.same_as(running_rsentries.back())) {\n        running_rsentries.pop_back();\n      }\n    }\n    while (running_rsentries.size() * (engine_config_->spec_draft_length + 1) >\n           std::min(static_cast<int64_t>(engine_config_->max_num_sequence),\n                    engine_config_->prefill_chunk_size)) {\n      running_rsentries.pop_back();\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n\n    int num_rsentries = running_rsentries.size();\n    TVM_FFI_ICHECK_GT(num_rsentries, 0)\n        << \"There should be at least one request state entry that can run decode. \"\n           \"Possible failure reason: none of the prefill phase of the running requests is finished\";\n    TVM_FFI_ICHECK_LE(num_rsentries, engine_config_->max_num_sequence)\n        << \"The number of running requests exceeds the max number of sequence in EngineConfig. \"\n           \"Possible failure reason: the prefill action allows new sequence in regardless of the \"\n           \"max num sequence.\";\n    Array<String> request_ids;\n    std::vector<int64_t> request_internal_ids;\n    Array<String> request_ids_per_leaf_node;\n    Array<GenerationConfig> generation_cfg;\n    Array<GenerationConfig> generation_cfg_for_logitproc;\n    std::vector<RandomGenerator*> rngs;\n    std::vector<std::vector<int>> draft_token_indices;\n    // Number of input tokens for each request. 
Each request can have multiple leaf tokens for the\n    // next forward when multiple tokens are drafted.\n    std::vector<int> cum_num_tokens;\n    std::vector<int64_t> token_tree_parent_ptr;\n    request_ids.reserve(num_rsentries);\n    request_internal_ids.reserve(num_rsentries);\n    generation_cfg.reserve(num_rsentries);\n    generation_cfg_for_logitproc.reserve(num_rsentries);\n    draft_token_indices.reserve(num_rsentries);\n    cum_num_tokens.reserve(num_rsentries + 1);\n    for (const RequestStateEntry& rsentry : running_rsentries) {\n      request_ids.push_back(rsentry->request->id);\n      request_internal_ids.push_back(rsentry->mstates[0]->internal_id);\n    }\n\n    TVM_FFI_ICHECK_GT(estate->spec_draft_length, 0)\n        << \"The speculative decoding draft length must be positive.\";\n    // The first model doesn't get involved in draft proposal.\n    for (int model_id = 1; model_id < static_cast<int>(models_.size()); ++model_id) {\n      // Collect\n      // - the last committed token,\n      // - the request model state of each request,\n      // - the number of tokens for each request to send into the model (it may\n      // be more than one if the draft model is lagging behind the main model, when\n      // the engine switches from normal batch decode mode to speculative decoding mode).\n      std::vector<int> input_tokens;\n      Array<RequestModelState> mstates;\n      std::vector<int> input_lengths;\n      input_tokens.reserve(num_rsentries);\n      mstates.reserve(num_rsentries);\n      input_lengths.reserve(num_rsentries);\n      for (const RequestStateEntry& rsentry : running_rsentries) {\n        mstates.push_back(rsentry->mstates[model_id]);\n      }\n      // \"Draft length\" rounds of draft proposal.\n      for (int draft_id = 0; draft_id < estate->spec_draft_length; ++draft_id) {\n        auto tdraft_start = std::chrono::high_resolution_clock::now();\n        // prepare new input tokens\n        input_tokens.clear();\n        
input_lengths.clear();\n        token_tree_parent_ptr.clear();\n        generation_cfg.clear();\n        generation_cfg_for_logitproc.clear();\n        rngs.clear();\n        cum_num_tokens.clear();\n        cum_num_tokens.push_back(0);\n        request_ids_per_leaf_node.clear();\n        std::vector<int> draft_token_parent_idx;\n        draft_token_indices.clear();\n\n        if (draft_id == 0) {\n          // Compute the total length that needs to be processed by the draft model,\n          // including the lagging-behind part of the draft model.\n          // When the total length to be processed is larger than the prefill chunk\n          // size, we must do the prefill with multiple rounds by chunk.\n          int total_length = 0;\n          for (int i = 0; i < num_rsentries; ++i) {\n            TVM_FFI_ICHECK_LE(mstates[i]->committed_tokens.size(),\n                              running_rsentries[i]->mstates[0]->committed_tokens.size());\n            total_length += running_rsentries[i]->mstates[0]->committed_tokens.size() -\n                            mstates[i]->committed_tokens.size() + 1;\n          }\n          if (total_length > engine_config_->prefill_chunk_size) {\n            PrefillLaggedTokensByChunk(mstates, running_rsentries, models_[model_id],\n                                       total_length - engine_config_->prefill_chunk_size);\n          }\n        }\n\n        for (int i = 0; i < num_rsentries; ++i) {\n          int num_leaf_nodes = 0;\n          // Starting from last committed tokens\n          if (draft_id == 0) {\n            TVM_FFI_ICHECK_LE(mstates[i]->committed_tokens.size(),\n                              running_rsentries[i]->mstates[0]->committed_tokens.size());\n            TVM_FFI_ICHECK_EQ(mstates[i]->num_tokens_for_next_decode, 1);\n            input_tokens.push_back(mstates[i]->committed_tokens.back().GetTokenId());\n            input_lengths.push_back(running_rsentries[i]->mstates[0]->committed_tokens.size() -\n         
                           mstates[i]->committed_tokens.size() + 1);\n            for (size_t j = mstates[i]->committed_tokens.size();\n                 j < running_rsentries[i]->mstates[0]->committed_tokens.size(); ++j) {\n              // This draft model is lagging behind the main model.\n              // It may happen when the engine just switches from the normal batch decode\n              // mode to the speculative decoding mode.\n              // In this case, we need to prefill the misaligned tokens into the draft model.\n              mstates[i]->CommitToken(running_rsentries[i]->mstates[0]->committed_tokens[j]);\n              input_tokens.push_back(\n                  running_rsentries[i]->mstates[0]->committed_tokens[j].GetTokenId());\n            }\n            mstates[i]->num_tokens_for_next_decode = 0;\n            draft_token_indices.emplace_back(std::vector<int>{-1});\n            rngs.push_back(&running_rsentries[i]->rng);\n            draft_token_parent_idx.push_back(-1);\n            request_ids_per_leaf_node.push_back(request_ids[i]);\n            num_leaf_nodes = 1;\n            cum_num_tokens.push_back(cum_num_tokens.back() + 1);\n          } else {\n            TVM_FFI_ICHECK_EQ(mstates[i]->committed_tokens.size(),\n                              running_rsentries[i]->mstates[0]->committed_tokens.size());\n            TVM_FFI_ICHECK(!mstates[i]->draft_output_tokens.empty());\n            draft_token_indices.emplace_back(std::vector<int>{});\n            // Get all leaf nodes\n            for (int j = 0; j < static_cast<int>(mstates[i]->draft_output_tokens.size()); ++j) {\n              if (mstates[i]->draft_token_first_child_idx[j] == -1) {\n                int64_t parent_idx = mstates[i]->draft_token_parent_idx[j];\n                token_tree_parent_ptr.push_back(parent_idx);\n                input_tokens.push_back(mstates[i]->draft_output_tokens[j].GetTokenId());\n                draft_token_indices.back().push_back(j);\n                
rngs.push_back(&running_rsentries[i]->rng);\n                num_leaf_nodes++;\n                request_ids_per_leaf_node.push_back(request_ids[i]);\n                draft_token_parent_idx.push_back(j);\n              }\n            }\n            input_lengths.push_back(num_leaf_nodes);\n            cum_num_tokens.push_back(cum_num_tokens.back() + input_lengths.back());\n          }\n          GenerationConfig generation_cfg_for_draft = [&]() {\n            if (engine_config_->spec_tree_width == 1) {\n              return mstates[i]->request->generation_cfg;\n            }\n            auto spec_generation_cfg = tvm::ffi::make_object<GenerationConfigNode>(\n                *(mstates[i]->request->generation_cfg.get()));\n            spec_generation_cfg->top_logprobs = engine_config_->spec_tree_width;\n            spec_generation_cfg->logprobs = true;\n            spec_generation_cfg->temperature = 1.0;\n            return GenerationConfig(spec_generation_cfg);\n          }();\n          for (int j = 0; j < num_leaf_nodes; ++j) {\n            generation_cfg.push_back(generation_cfg_for_draft);\n          }\n          generation_cfg_for_logitproc.push_back(generation_cfg_for_draft);\n        }\n\n        // - Compute embeddings.\n        RECORD_EVENT(trace_recorder_, request_ids, \"start proposal embedding\");\n        TVM_FFI_ICHECK_LE(input_tokens.size(), engine_config_->prefill_chunk_size);\n        ObjectRef embeddings =\n            models_[model_id]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}});\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish proposal embedding\");\n\n        // - Invoke model decode.\n        RECORD_EVENT(trace_recorder_, request_ids, \"start proposal decode\");\n        Tensor logits{nullptr};\n\n        if (input_tokens.size() == num_rsentries) {\n          // Each request entry only has one token to feed into the draft model.\n          logits = models_[model_id]->BatchDecode(embeddings, 
request_internal_ids);\n          TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n          TVM_FFI_ICHECK_EQ(logits->shape[0], num_rsentries);\n          TVM_FFI_ICHECK_EQ(logits->shape[1], 1);\n        } else if (draft_id == 0) {\n          // There exists some request entry which has more than one token to feed.\n          // It may happen when the engine just switches from the normal batch decode\n          // mode to the speculative decoding mode.\n          logits = models_[model_id]->BatchPrefill(embeddings, request_internal_ids, input_lengths);\n          TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n          TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n          TVM_FFI_ICHECK_EQ(logits->shape[1], num_rsentries);\n        } else {\n          TVM_FFI_ICHECK_GT(engine_config_->spec_tree_width, 1);\n          logits = models_[model_id]->BatchTreeDecode(embeddings, request_internal_ids,\n                                                      input_lengths, token_tree_parent_ptr);\n          TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n          TVM_FFI_ICHECK_EQ(logits->shape[0], cum_num_tokens.back());\n          TVM_FFI_ICHECK_EQ(logits->shape[1], 1);\n        }\n        TVM_FFI_ICHECK_EQ(input_lengths.size(), num_rsentries);\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish proposal decode\");\n\n        // - Update logits.\n        logits = logits.CreateView({cum_num_tokens.back(), logits->shape[2]}, logits->dtype);\n\n        logit_processor_->InplaceUpdateLogits(logits, generation_cfg_for_logitproc, mstates,\n                                              request_ids, &cum_num_tokens, &mstates,\n                                              &draft_token_indices);\n\n        // - Compute probability distributions.\n        Tensor probs_on_device = logit_processor_->ComputeProbsFromLogits(\n            logits, generation_cfg_for_logitproc, request_ids, &cum_num_tokens);\n\n        // - Commit the prefix cache changes from previous round of action.\n        // Note: we commit 
prefix cache changes here to overlap this commit with the GPU execution.\n        estate->prefix_cache->CommitSequenceExtention();\n\n        // - Sample tokens.\n        // Fill range [0, cum_num_tokens.back()) into `sample_indices`: one index per leaf node.\n        std::vector<int> sample_indices(cum_num_tokens.back());\n        std::iota(sample_indices.begin(), sample_indices.end(), 0);\n        Tensor renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(\n            probs_on_device, sample_indices, request_ids_per_leaf_node, generation_cfg);\n        std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(\n            renormalized_probs, sample_indices, request_ids_per_leaf_node, generation_cfg, rngs);\n        TVM_FFI_ICHECK_EQ(sample_results.size(), cum_num_tokens.back());\n\n        // - Add draft token to the state.\n        draft_token_workspace_manager_->AllocSlots(cum_num_tokens.back(), &draft_token_slots_);\n        models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_,\n                                             &model_workspaces_[0].draft_probs_storage);\n        for (int i = 0; i < num_rsentries; ++i) {\n          for (int j = cum_num_tokens[i]; j < cum_num_tokens[i + 1]; ++j) {\n            int parent_idx = draft_token_parent_idx[j];\n            if (engine_config_->spec_tree_width == 1) {\n              mstates[i]->AddDraftToken(sample_results[j], draft_token_slots_[j], parent_idx);\n              continue;\n            }\n            for (int k = 0; k < sample_results[j].top_prob_tokens.size(); ++k) {\n              SampleResult top_k_token{sample_results[j].top_prob_tokens[k]};\n              mstates[i]->AddDraftToken(top_k_token, draft_token_slots_[j], parent_idx);\n            }\n          }\n        }\n\n        auto tdraft_end = std::chrono::high_resolution_clock::now();\n        estate->metrics.UpdateDraftTimeByBatchSize(\n            num_rsentries, 
static_cast<double>((tdraft_end - tdraft_start).count()) / 1e9);\n      }\n    }\n\n    auto tend = std::chrono::high_resolution_clock::now();\n    estate->metrics.engine_decode_time_sum += static_cast<double>((tend - tstart).count()) / 1e9;\n\n    return {};\n  }\n\n private:\n  /*!\n   * \\brief Check whether every draft model has at least one available KV cache page per\n   * request state entry, which is required to decode all the entries.\n   */\n  bool CanDecode(int num_rsentries) {\n    // The first model is not involved in draft proposal.\n    for (int model_id = 1; model_id < static_cast<int>(models_.size()); ++model_id) {\n      // Check if the model has enough available pages.\n      int num_available_pages = models_[model_id]->GetNumAvailablePages();\n      if (num_rsentries > num_available_pages) {\n        return false;\n      }\n    }\n    return true;\n  }\n\n  /*!\n   * \\brief Prefill into the draft model the tokens by which it lags behind the main model,\n   * in chunks of at most `prefill_chunk_size` tokens, prefilling at most\n   * `remaining_prefill_length` tokens in total.\n   */\n  void PrefillLaggedTokensByChunk(const Array<RequestModelState>& mstates,\n                                  const std::vector<RequestStateEntry>& running_rsentries,\n                                  Model model, int remaining_prefill_length) {\n    int num_rsentries = mstates.size();\n    std::vector<int> input_tokens;\n    std::vector<int64_t> request_internal_ids;\n    std::vector<int> lengths;\n    input_tokens.reserve(engine_config_->prefill_chunk_size);\n    request_internal_ids.reserve(num_rsentries);\n    lengths.reserve(num_rsentries);\n\n    auto f_run_prefill = [&model, &input_tokens, &request_internal_ids, &lengths]() {\n      ObjectRef embeddings =\n          model->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}});\n      model->BatchPrefill(embeddings, request_internal_ids, lengths);\n    };\n\n    for (int i = 0; i < num_rsentries; ++i) {\n      int prefill_length =\n          std::min({static_cast<int>(running_rsentries[i]->mstates[0]->committed_tokens.size() -\n                                     mstates[i]->committed_tokens.size()),\n                    static_cast<int>(engine_config_->prefill_chunk_size - input_tokens.size()),\n                    
remaining_prefill_length});\n      if (prefill_length == 0) {\n        // This rsentry is done.\n        continue;\n      }\n\n      TVM_FFI_ICHECK(!mstates[i]->committed_tokens.empty());\n      int draft_num_committed = static_cast<int>(mstates[i]->committed_tokens.size());\n      for (int j = draft_num_committed; j < draft_num_committed + prefill_length; ++j) {\n        // Commit the lagging-behind token at position j to the draft model, and feed the\n        // token at position j - 1, since the last committed token of the draft model has\n        // not been fed into it yet.\n        mstates[i]->CommitToken(running_rsentries[i]->mstates[0]->committed_tokens[j]);\n        input_tokens.push_back(\n            running_rsentries[i]->mstates[0]->committed_tokens[j - 1].GetTokenId());\n      }\n      lengths.push_back(prefill_length);\n      request_internal_ids.push_back(running_rsentries[i]->mstates[0]->internal_id);\n      mstates[i]->num_tokens_for_next_decode = 1;\n      remaining_prefill_length -= prefill_length;\n      if (remaining_prefill_length == 0) {\n        // All rsentries are done.\n        break;\n      }\n\n      if (input_tokens.size() == engine_config_->prefill_chunk_size) {\n        // Run prefill if the pending part total length reaches the prefill chunk size.\n        f_run_prefill();\n        input_tokens.clear();\n        request_internal_ids.clear();\n        lengths.clear();\n        --i;\n        continue;\n      }\n    }\n\n    if (!input_tokens.empty()) {\n      f_run_prefill();\n    }\n  }\n\n  /*! \\brief The model to run draft generation in speculative decoding. */\n  Array<Model> models_;\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief The model workspaces. */\n  std::vector<ModelWorkspace> model_workspaces_;\n  /*! \\brief The draft token workspace manager. */\n  DraftTokenWorkspaceManager draft_token_workspace_manager_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! 
\\brief Temporary buffer to store the slots of the current draft tokens */\n  std::vector<int> draft_token_slots_;\n};\n\nEngineAction EngineAction::BatchDraft(Array<Model> models, LogitProcessor logit_processor,\n                                      Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                      DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                      EngineConfig engine_config,\n                                      Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<BatchDraftActionObj>(\n      std::move(models), std::move(logit_processor), std::move(sampler),\n      std::move(model_workspaces), std::move(draft_token_workspace_manager),\n      std::move(engine_config), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/batch_jumpforward.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/batch_jumpforward.cc\n */\n\n#include <tvm/runtime/nvtx.h>\n#include <tvm/runtime/threading_backend.h>\n\n#include <cmath>\n#include <exception>\n\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs jump-forward decoding for requests in the\n * `running_queue` of engine state. Preempt low-priority requests accordingly\n * when there are not enough KV cache pages for jump-forward decoding.\n */\nclass BatchJumpForwardActionObj : public EngineActionObj {\n public:\n  explicit BatchJumpForwardActionObj(Array<Model> models, Tokenizer tokenizer,\n                                     Optional<EventTraceRecorder> trace_recorder)\n      : models_(std::move(models)),\n        tokenizer_(std::move(tokenizer)),\n        trace_recorder_(std::move(trace_recorder)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Do not run jump-forward decoding when there are multiple models or no running requests.\n    if (models_.size() > 1 || estate->running_queue.empty()) {\n      return {};\n    }\n\n    // Preempt request state entries when jump-forward decoding cannot apply.\n    std::vector<RequestStateEntry> running_rsentries;\n    {\n      NVTXScopedRange nvtx_scope(\"BatchJumpForward getting requests\");\n      running_rsentries = estate->GetRunningRequestStateEntries();\n      while (!CheckMemForJumpForward(running_rsentries.size())) {\n        if (estate->prefix_cache->TryFreeMemory()) continue;\n        RequestStateEntry preempted =\n            PreemptLastRunningRequestStateEntry(estate, models_, std::nullopt, trace_recorder_);\n        if (preempted.same_as(running_rsentries.back())) {\n          running_rsentries.pop_back();\n        }\n      }\n    }\n\n    if (running_rsentries.empty()) {\n      return {};\n    }\n\n    auto tstart = 
std::chrono::high_resolution_clock::now();\n\n    for (auto rsentry : running_rsentries) {\n      if (!CanJumpForward(rsentry)) {\n        continue;\n      }\n\n      auto mstate = rsentry->mstates[0];\n      auto jump_forward_str = mstate->grammar_matcher->FindJumpForwardString();\n\n      if (jump_forward_str.empty()) {\n        continue;\n      }\n\n      auto [rollback_cnt, new_tokens, new_string] =\n          RetokenizeWithNewString(mstate, jump_forward_str, MAX_ROLLBACK_TOKENS_);\n\n      HandleRollback(rsentry, mstate, rollback_cnt, new_tokens, new_string);\n\n      // Commit new tokens (kv cache is handled in the next decode)\n      for (auto token_id : new_tokens) {\n        mstate->CommitToken({{token_id, 1.0}, {}});\n      }\n\n      mstate->require_retokenization_in_next_decode = true;\n\n      // Update metrics\n      rsentry->rstate->metrics.jump_forward_tokens +=\n          std::max(static_cast<int>(new_tokens.size()) - rollback_cnt, 0);\n\n      rsentry->rstate->metrics.completion_tokens +=\n          static_cast<int>(new_tokens.size()) - rollback_cnt;\n    }\n\n    auto tend = std::chrono::high_resolution_clock::now();\n    estate->metrics.engine_jump_forward_time_sum +=\n        static_cast<double>((tend - tstart).count()) / 1e9;\n\n    return {};\n  }\n\n private:\n  /*! \\brief Check if jump-forward decoding can be executed without exceeding the memory limit. */\n  bool CheckMemForJumpForward(int num_rsentries) {\n    static constexpr int MAX_AVG_JUMPFORWARD_PAGES_PER_REQUEST = 10;\n    int num_available_pages = models_[0]->GetNumAvailablePages();\n    return num_rsentries * MAX_AVG_JUMPFORWARD_PAGES_PER_REQUEST <= num_available_pages;\n  }\n\n  /*! \\brief Check if jump-forward decoding can be executed. It is not executed when the\n   * grammar execution mode is not jump-forward, when logprobs are requested, or when the\n   * grammar state matcher is not defined. 
*/\n  bool CanJumpForward(const RequestStateEntry& rsentry) {\n    if (rsentry->request->generation_cfg->debug_config.grammar_execution_mode !=\n        GrammarExecutionMode::kJumpForward) {\n      return false;\n    }\n    if (rsentry->request->generation_cfg->logprobs) {\n      return false;\n    }\n    if (!rsentry->mstates[0]->grammar_matcher) {\n      return false;\n    }\n    return true;\n  }\n\n  /*!\n   * \\brief Retokenize the input string with a new string.\n   * \\param mstate The model state.\n   * \\param new_string The new string to append.\n   * \\param max_rollback_tokens The maximum number of tokens to rollback.\n   * \\return The number of tokens to rollback, the new tokens and a delta string of output (equal to\n   * new_string if no cutoff happens; shorter than new_string if cutoff happens).\n   */\n  std::tuple<int, std::vector<int32_t>, std::string> RetokenizeWithNewString(\n      RequestModelState mstate, const std::string& new_string, int max_rollback_tokens) {\n    // Step 1. Get past tokens\n    // past_tokens = mstate[-max_rollback_tokens:]\n    // past_string = detokenize(past_tokens)\n    const auto& token_table = tokenizer_->PostProcessedTokenTable();\n    std::vector<int32_t> past_tokens;\n    std::string past_string;\n    auto past_begin_it = mstate->committed_tokens.size() >= max_rollback_tokens\n                             ? mstate->committed_tokens.end() - max_rollback_tokens\n                             : mstate->committed_tokens.begin();\n    for (auto it = past_begin_it; it != mstate->committed_tokens.end(); ++it) {\n      past_tokens.push_back(it->GetTokenId());\n      past_string += token_table[it->GetTokenId()];\n    }\n\n    // Step 2. Retokenize\n    // Compare tokenize(past_string + new_string) and past_tokens\n    auto new_tokens = tokenizer_->EncodeNoPrependSpace(past_string + new_string);\n    auto delta_string = new_string;\n\n    // Pop last token if it is a prefix of another token. 
That's because such tokens will often\n    // be rolled back in the next decode, which disturbs the distribution, so we will avoid\n    // generating them.\n    if (tokenizer_->GetPrefixTokenMask()[new_tokens.back()]) {\n      auto last_token = token_table[new_tokens.back()];\n      if (last_token.length() >= new_string.length()) {\n        return {0, {}, \"\"};\n      }\n\n      delta_string = delta_string.substr(0, delta_string.length() - last_token.length());\n      new_tokens.pop_back();\n    }\n\n    int first_differ_idx = past_tokens.size();\n    for (int i = 0; i < static_cast<int>(past_tokens.size()); ++i) {\n      if (i == static_cast<int>(new_tokens.size()) || past_tokens[i] != new_tokens[i]) {\n        first_differ_idx = i;\n        break;\n      }\n    }\n\n    return {past_tokens.size() - first_differ_idx,\n            std::vector<int32_t>(new_tokens.begin() + first_differ_idx, new_tokens.end()),\n            delta_string};\n  }\n\n  /*!\n   * \\brief Handle rollback for the stream output, the model state and the kv cache.\n   * \\param rsentry The request state entry.\n   * \\param mstate The model state.\n   * \\param rollback_cnt The number of tokens to rollback.\n   * \\param new_tokens The new tokens. Useful for the stream output.\n   * \\param new_string The delta string of output. Useful for the stream output.\n   */\n  void HandleRollback(const RequestStateEntry& rsentry, RequestModelState mstate, int rollback_cnt,\n                      const std::vector<int32_t>& new_tokens, const std::string& new_string) {\n    // 1. 
Handle rollback for the stream output\n    if (rollback_cnt >\n        static_cast<int>(mstate->committed_tokens.size()) - rsentry->next_callback_token_pos) {\n      const auto& token_table = tokenizer_->PostProcessedTokenTable();\n      for (auto i = rsentry->next_callback_token_pos; i < mstate->committed_tokens.size(); ++i) {\n        auto token_id = mstate->committed_tokens[i].GetTokenId();\n        rsentry->extra_prefix_string += token_table[token_id];\n      }\n      rsentry->extra_prefix_string += new_string;\n      rsentry->next_callback_token_pos = static_cast<int>(mstate->committed_tokens.size()) -\n                                         rollback_cnt + static_cast<int>(new_tokens.size());\n    }\n\n    // 2. Handle rollback for the model state\n    if (rollback_cnt > 0) {\n      mstate->RollbackTokens(rollback_cnt);\n    }\n\n    // 3. Handle rollback for the kv cache\n    if (rollback_cnt > mstate->num_tokens_for_next_decode) {\n      models_[0]->PopNFromKVCache(mstate->internal_id,\n                                  rollback_cnt - mstate->num_tokens_for_next_decode);\n      mstate->num_tokens_for_next_decode = 0;\n    } else {\n      mstate->num_tokens_for_next_decode -= rollback_cnt;\n    }\n  }\n\n  /*!\n   * \\brief The model to run jump-forward decoding. When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   */\n  Array<Model> models_;\n  /*! \\brief Tokenizer for retokenization. */\n  Tokenizer tokenizer_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! \\brief The maximum number of tokens to rollback. 
*/\n  const int MAX_ROLLBACK_TOKENS_ = 10;\n};\n\nEngineAction EngineAction::BatchJumpForward(Array<Model> models, Tokenizer tokenizer,\n                                            Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<BatchJumpForwardActionObj>(\n      std::move(models), std::move(tokenizer), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/batch_prefill_base.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/batch_prefill_base.cc\n */\n\n#include \"batch_prefill_base.h\"\n\n#include <limits>\n#include <numeric>\n\n#include \"../../support/json_parser.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief Check whether the KV cache has room for a prefill batch. The required pages (plus one\n * page per sequence when sliding window is disabled) must fit in the available pages, and, when\n * sliding window is disabled, the total sequence length after prefill (with a margin of 8 tokens\n * per sequence) must not exceed `max_total_sequence_length`.\n */\nbool HasPrefillSpace(int num_required_pages, bool sliding_window_enabled, int new_batch_size,\n                     int num_available_pages, int current_total_seq_len, int total_input_length,\n                     int max_total_sequence_length) {\n  return num_required_pages + (!sliding_window_enabled ? new_batch_size : 0) <=\n             num_available_pages &&\n         (sliding_window_enabled ||\n          current_total_seq_len + total_input_length + 8 * new_batch_size <=\n              max_total_sequence_length);\n}\n\nBatchPrefillBaseActionObj::BatchPrefillBaseActionObj(\n    Array<Model> models, EngineConfig engine_config,\n    std::vector<tvm::ffi::json::Object> model_configs, Optional<EventTraceRecorder> trace_recorder)\n    : models_(std::move(models)),\n      engine_config_(std::move(engine_config)),\n      trace_recorder_(std::move(trace_recorder)) {\n  TVM_FFI_ICHECK_EQ(models_.size(), model_configs.size());\n  sliding_window_sizes_.reserve(models_.size());\n  for (const tvm::ffi::json::Object& model_config : model_configs) {\n    // \"-1\" means the sliding window is disabled.\n    sliding_window_sizes_.push_back(\n        json::LookupOrDefault<int64_t>(model_config, \"sliding_window_size\", -1));\n  }\n  kv_state_kind_ = models_[0]->GetMetadata().kv_state_kind;\n}\n\n/*!\n * \\brief Find one or multiple request state entries to run prefill.\n * \\param estate The engine state.\n * \\return The request entries to prefill, together with their input lengths.\n */\nstd::vector<BatchPrefillBaseActionObj::PrefillInput>\nBatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill(EngineState estate) {\n  // Skip prefill when there are not enough KV cache pages to decode the running entries.\n  const 
std::vector<RequestStateEntry>* running_rsentries;\n  {\n    NVTXScopedRange nvtx_scope(\"BatchDecode getting requests\");\n    running_rsentries = &estate->GetRunningRequestStateEntries();\n    if (!(running_rsentries->size() <= models_[0]->GetNumAvailablePages())) {\n      // Even the decode cannot be performed.\n      // As a result, directly return without doing prefill.\n      return {};\n    }\n  }\n\n  if (estate->waiting_queue.empty()) {\n    // No request to prefill.\n    return {};\n  }\n\n  std::vector<std::vector<PrefillInput>> prefill_inputs_for_all_models;\n  prefill_inputs_for_all_models.reserve(models_.size());\n\n  int num_decode_inputs = static_cast<int>(running_rsentries->size());\n\n  // We first collect the inputs that can be prefilled for each model.\n  // Then we make a reduction to return the maximum common inputs.\n  for (int i = 0; i < static_cast<int>(models_.size()); ++i) {\n    std::vector<PrefillInput> prefill_inputs;\n    // - Try to prefill pending requests.\n    int total_input_length = 0;\n    for (const RequestStateEntry& rsentry : *running_rsentries) {\n      total_input_length += rsentry->mstates[i]->num_tokens_for_next_decode;\n    }\n    int total_required_pages = num_decode_inputs;\n    int num_available_pages;\n    int num_running_rsentries = num_decode_inputs;\n    int current_total_seq_len;\n    {\n      NVTXScopedRange nvtx_scope(\"KV cache GetNumAvailablePages\");\n      num_available_pages = models_[i]->GetNumAvailablePages();\n    }\n    {\n      NVTXScopedRange nvtx_scope(\"KV cache GetCurrentTotalSequenceLength\");\n      current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();\n    }\n\n    int num_prefill_rsentries = 0;\n    for (const Request& request : estate->waiting_queue) {\n      NVTXScopedRange nvtx_scope(\"Process request \" + request->id);\n      if (request->generation_cfg->debug_config.disagg_config.kind != DisaggRequestKind::kNone) {\n        continue;\n      }\n      RequestState rstate = 
estate->GetRequestState(request);\n      bool prefill_stops = false;\n      for (const RequestStateEntry& rsentry : rstate->entries) {\n        // A request state entry can be prefilled only when:\n        // - it has inputs, and\n        // - it has no parent or its parent is alive and has no remaining input.\n        if (rsentry->mstates[i]->inputs.empty() ||\n            (rsentry->parent_idx != -1 &&\n             (rstate->entries[rsentry->parent_idx]->status == RequestStateStatus::kPending ||\n              !rstate->entries[rsentry->parent_idx]->mstates[i]->inputs.empty()))) {\n          continue;\n        }\n\n        int input_length = rsentry->mstates[i]->GetInputLength();\n        int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /\n                                engine_config_->kv_cache_page_size;\n        bool sliding_window_enabled = sliding_window_sizes_[i] != -1;\n        int num_required_pages_under_sliding_window = std::numeric_limits<int>::max();\n        if (sliding_window_enabled) {\n          // Sliding window for model i is enabled.\n          int max_single_request_page_requirement =\n              1 + (sliding_window_sizes_[i] + engine_config_->kv_cache_page_size - 1) /\n                      engine_config_->kv_cache_page_size;\n          int num_total_prefilled_tokens = rsentry->mstates[i]->num_prefilled_tokens;\n          int parent_ptr = rsentry->parent_idx;\n          while (parent_ptr != -1) {\n            num_total_prefilled_tokens +=\n                rstate->entries[parent_ptr]->mstates[i]->num_prefilled_tokens;\n            parent_ptr = rstate->entries[parent_ptr]->parent_idx;\n          }\n\n          int num_pages_in_use = (std::min(num_total_prefilled_tokens, sliding_window_sizes_[i]) +\n                                  engine_config_->kv_cache_page_size - 1) /\n                                 engine_config_->kv_cache_page_size;\n          num_required_pages_under_sliding_window =\n              
max_single_request_page_requirement - num_pages_in_use;\n          num_require_pages = std::min(num_require_pages, num_required_pages_under_sliding_window);\n          TVM_FFI_ICHECK_GE(num_require_pages, 0);\n        }\n\n        total_input_length += input_length;\n        total_required_pages += num_require_pages;\n        // - Attempt 1. Check if the entire request state entry can fit for prefill.\n        bool can_prefill = false;\n        {\n          NVTXScopedRange nvtx_scope(\"Attempt 1\");\n          for (int num_child_to_activate = rsentry->child_indices.size();\n               num_child_to_activate >= 0; --num_child_to_activate) {\n            while (!HasPrefillSpace(total_required_pages, sliding_window_enabled,\n                                    (num_running_rsentries + num_prefill_rsentries),\n                                    num_available_pages, current_total_seq_len, total_input_length,\n                                    engine_config_->max_total_sequence_length)) {\n              if (!estate->prefix_cache->TryFreeMemory()) break;\n              // Update number of available pages after memory free.\n              num_available_pages = models_[i]->GetNumAvailablePages();\n              current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();\n            }\n            if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate,\n                           total_input_length, total_required_pages, num_available_pages,\n                           current_total_seq_len, num_running_rsentries, kv_state_kind_,\n                           sliding_window_enabled)) {\n              prefill_inputs.push_back(\n                  {rsentry, input_length, num_child_to_activate, /*is_decode=*/false});\n              num_prefill_rsentries += 1 + num_child_to_activate;\n              can_prefill = true;\n              break;\n            }\n          }\n        }\n        if (can_prefill) {\n          continue;\n        }\n        
total_input_length -= input_length;\n        total_required_pages -= num_require_pages;\n\n        // - Attempt 2. Check if the request state entry can partially fit by input chunking.\n        TVM_FFI_ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size);\n        if (engine_config_->prefill_chunk_size - total_input_length >= input_length ||\n            engine_config_->prefill_chunk_size == total_input_length) {\n          // 1. If the input length can fit the remaining prefill chunk size,\n          // it means the failure of attempt 1 is not because of the input\n          // length being too long, and thus chunking does not help.\n          // 2. If the total input length already reaches the prefill chunk size,\n          // the current request state entry will not be able to be processed.\n          // So we can safely return in either case.\n          prefill_stops = true;\n          break;\n        }\n        input_length = engine_config_->prefill_chunk_size - total_input_length;\n        num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /\n                            engine_config_->kv_cache_page_size;\n        if (sliding_window_enabled) {\n          // Sliding window for model i is enabled.\n          num_require_pages = std::min(num_require_pages, num_required_pages_under_sliding_window);\n          TVM_FFI_ICHECK_GE(num_require_pages, 0);\n        }\n\n        {\n          NVTXScopedRange nvtx_scope(\"Attempt 2\");\n          total_input_length += input_length;\n          total_required_pages += num_require_pages;\n          if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length,\n                         total_required_pages, num_available_pages, current_total_seq_len,\n                         num_running_rsentries, kv_state_kind_, sliding_window_enabled)) {\n            prefill_inputs.push_back({rsentry, input_length, 0, /*is_decode=*/false});\n          }\n        }\n\n        // - Prefill stops 
here.\n        prefill_stops = true;\n        break;\n      }\n      if (prefill_stops) {\n        break;\n      }\n    }\n    prefill_inputs_for_all_models.push_back(prefill_inputs);\n  }\n\n  // Reduce over the prefill inputs of all models.\n  TVM_FFI_ICHECK(!prefill_inputs_for_all_models.empty());\n  int num_prefill_inputs = prefill_inputs_for_all_models[0].size();\n  for (int i = 1; i < static_cast<int>(prefill_inputs_for_all_models.size()); ++i) {\n    num_prefill_inputs =\n        std::min(num_prefill_inputs, static_cast<int>(prefill_inputs_for_all_models[i].size()));\n  }\n\n  if (num_prefill_inputs == 0) {\n    return {};\n  }\n\n  // Add the decode requests to the prefill inputs if prefill mode is hybrid.\n  std::vector<PrefillInput> prefill_inputs(prefill_inputs_for_all_models[0].begin(),\n                                           prefill_inputs_for_all_models[0].end());\n  if (engine_config_->prefill_mode == PrefillMode::kHybrid) {\n    prefill_inputs.reserve(num_decode_inputs + num_prefill_inputs);\n    for (const RequestStateEntry& rsentry : *running_rsentries) {\n      prefill_inputs.push_back(\n          {rsentry, rsentry->mstates[0]->num_tokens_for_next_decode, 0, /*is_decode=*/true});\n    }\n  }\n  {\n    NVTXScopedRange nvtx_scope(\"reduction\");\n    for (int i = 1; i < static_cast<int>(prefill_inputs_for_all_models.size()); ++i) {\n      // Prefill input lengths except the last one are supposed to be the same for all models.\n      for (int j = 0; j < num_prefill_inputs - 1; ++j) {\n        TVM_FFI_ICHECK(\n            prefill_inputs_for_all_models[i][j].rsentry.same_as(prefill_inputs[j].rsentry));\n        TVM_FFI_ICHECK_EQ(prefill_inputs_for_all_models[i][j].max_prefill_length,\n                          prefill_inputs[j].max_prefill_length);\n        prefill_inputs[j].num_child_to_activate =\n            std::min(prefill_inputs[j].num_child_to_activate,\n                     prefill_inputs_for_all_models[i][j].num_child_to_activate);\n      
}\n      // The input length of the last input is the minimum among all models.\n      TVM_FFI_ICHECK(prefill_inputs_for_all_models[i][num_prefill_inputs - 1].rsentry.same_as(\n          prefill_inputs[num_prefill_inputs - 1].rsentry));\n      prefill_inputs[num_prefill_inputs - 1].max_prefill_length =\n          std::min(prefill_inputs[num_prefill_inputs - 1].max_prefill_length,\n                   prefill_inputs_for_all_models[i][num_prefill_inputs - 1].max_prefill_length);\n      prefill_inputs[num_prefill_inputs - 1].num_child_to_activate =\n          std::min(prefill_inputs[num_prefill_inputs - 1].num_child_to_activate,\n                   prefill_inputs_for_all_models[i][num_prefill_inputs - 1].num_child_to_activate);\n    }\n  }\n\n  return prefill_inputs;\n}\n\nbool BatchPrefillBaseActionObj::CanPrefill(EngineState estate, int num_prefill_rsentries,\n                                           int total_input_length, int num_required_pages,\n                                           int num_available_pages, int current_total_seq_len,\n                                           int num_running_rsentries, KVStateKind kv_state_kind,\n                                           bool sliding_window_enabled) {\n  TVM_FFI_ICHECK_LE(num_running_rsentries, engine_config_->max_num_sequence);\n\n  // For RNN State, it can prefill as long as it can be instantiated.\n  if (kv_state_kind == KVStateKind::kRNNState || kv_state_kind == KVStateKind::kNone) {\n    return true;\n  }\n\n  // No exceeding of the maximum allowed requests that can\n  // run simultaneously.\n  int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable\n                        ? 
(estate->spec_draft_length + 1)\n                        : 1;\n  if ((num_running_rsentries + num_prefill_rsentries) * spec_factor >\n      std::min(static_cast<int64_t>(engine_config_->max_num_sequence),\n               engine_config_->prefill_chunk_size)) {\n    return false;\n  }\n\n  // NOTE: The conditions are heuristic and can be revised.\n  // Cond 1: total input length <= prefill chunk size.\n  // Cond 2: at least one decode can be performed after prefill.\n  // Cond 3: the total number of tokens after 8 decode steps does not\n  // exceed the limit, where 8 is a watermark number that can\n  // be configured and adjusted in the future.\n  return total_input_length <= engine_config_->prefill_chunk_size &&\n         HasPrefillSpace(num_required_pages, sliding_window_enabled,\n                         (num_running_rsentries + num_prefill_rsentries), num_available_pages,\n                         current_total_seq_len, total_input_length,\n                         engine_config_->max_total_sequence_length);\n}\n\n/*!\n * \\brief Chunk the input of the given RequestModelState for prefill\n * with regard to the provided maximum allowed prefill length.\n * Return the list of input for prefill and the total prefill length.\n * The `inputs` field of the given `mstate` will be mutated to exclude\n * the returned input.\n * \\param mstate The RequestModelState whose input data is to be chunked.\n * \\param max_prefill_length The maximum allowed prefill length for the mstate.\n * \\return The list of input for prefill and the total prefill length.\n */\nstd::pair<Array<Data>, int> BatchPrefillBaseActionObj::ChunkPrefillInputData(\n    const RequestModelState& mstate, int max_prefill_length) {\n  if (mstate->inputs.empty()) {\n    // The request is a hybrid decode request with no remaining prefill input.\n    TVM_FFI_ICHECK(mstate->num_tokens_for_next_decode > 0);\n    int num_tokens = mstate->num_tokens_for_next_decode;\n    mstate->num_tokens_for_next_decode = 0;\n    std::vector<int32_t> 
decode_tokens;\n    decode_tokens.reserve(num_tokens);\n    for (auto begin = mstate->committed_tokens.end() - num_tokens;\n         begin != mstate->committed_tokens.end(); ++begin) {\n      decode_tokens.push_back(begin->GetTokenId());\n    }\n    return {{TokenData(decode_tokens)}, num_tokens};\n  }\n  TVM_FFI_ICHECK(!mstate->inputs.empty());\n  std::vector<Data> inputs;\n  int cum_input_length = 0;\n  inputs.reserve(mstate->inputs.size());\n  for (int i = 0; i < static_cast<int>(mstate->inputs.size()); ++i) {\n    inputs.push_back(mstate->inputs[i]);\n    int input_length = mstate->inputs[i]->GetLength();\n    cum_input_length += input_length;\n    // Case 0. the cumulative input length does not reach the maximum prefill length.\n    if (cum_input_length < max_prefill_length) {\n      continue;\n    }\n\n    // Case 1. the cumulative input length equals the maximum prefill length.\n    if (cum_input_length == max_prefill_length) {\n      if (i == static_cast<int>(mstate->inputs.size()) - 1) {\n        // - If `i` is the last input, we just copy and reset `mstate->inputs`.\n        mstate->inputs.clear();\n      } else {\n        // - Otherwise, set the new input array.\n        mstate->inputs = Array<Data>{mstate->inputs.begin() + i + 1, mstate->inputs.end()};\n      }\n      return {inputs, cum_input_length};\n    }\n\n    // Case 2. 
cum_input_length > max_prefill_length\n    // The input `i` itself needs chunking if it is TokenData,\n    // or otherwise it cannot be chunked.\n    Data input = mstate->inputs[i];\n    inputs.pop_back();\n    cum_input_length -= input_length;\n    const auto* token_input = input.as<TokenDataNode>();\n    if (token_input == nullptr) {\n      // Cannot chunk the input.\n      if (i != 0) {\n        mstate->inputs = Array<Data>{mstate->inputs.begin() + i, mstate->inputs.end()};\n      }\n      return {inputs, cum_input_length};\n    }\n\n    // Split the token data into two parts.\n    // Return the first part for prefill, and keep the second part.\n    int chunked_input_length = max_prefill_length - cum_input_length;\n    TVM_FFI_ICHECK_GT(input_length, chunked_input_length);\n    TokenData chunked_input(IntTuple{token_input->token_ids.begin(),\n                                     token_input->token_ids.begin() + chunked_input_length});\n    TokenData remaining_input(IntTuple{token_input->token_ids.begin() + chunked_input_length,\n                                       token_input->token_ids.end()});\n    inputs.push_back(chunked_input);\n    cum_input_length += chunked_input_length;\n    std::vector<Data> remaining_inputs{mstate->inputs.begin() + i + 1, mstate->inputs.end()};\n    remaining_inputs.insert(remaining_inputs.begin(), remaining_input);\n    mstate->inputs = remaining_inputs;\n    return {inputs, cum_input_length};\n  }\n\n  TVM_FFI_ICHECK(false) << \"Cannot reach here\";\n}\n\nvoid BatchPrefillBaseActionObj::UpdateRequestToAlive(\n    const std::vector<BatchPrefillBaseActionObj::PrefillInput>& prefill_inputs,\n    const EngineState& estate, Array<String>* request_ids,\n    std::vector<RequestState>* rstates_of_entries,\n    std::vector<RequestStateStatus>* status_before_prefill) {\n  int num_rsentries = prefill_inputs.size();\n  request_ids->reserve(num_rsentries);\n  rstates_of_entries->reserve(num_rsentries);\n  
status_before_prefill->reserve(num_rsentries);\n  for (const PrefillInput& prefill_input : prefill_inputs) {\n    const RequestStateEntry& rsentry = prefill_input.rsentry;\n    const Request& request = rsentry->request;\n    RequestState request_rstate = estate->GetRequestState(request);\n    request_ids->push_back(request->id);\n    status_before_prefill->push_back(rsentry->status);\n    rsentry->status = RequestStateStatus::kAlive;\n\n    if (status_before_prefill->back() == RequestStateStatus::kPending) {\n      // - Add the request to running queue if the request state\n      // status was pending and all its request states were pending.\n      bool alive_state_existed = false;\n      for (const RequestStateEntry& rsentry_ : request_rstate->entries) {\n        if (rsentry_->status == RequestStateStatus::kAlive && !rsentry_.same_as(rsentry)) {\n          alive_state_existed = true;\n        }\n      }\n      if (!alive_state_existed) {\n        estate->running_queue.push_back(request);\n      }\n    }\n    rstates_of_entries->push_back(std::move(request_rstate));\n  }\n}\n\nstd::vector<Request> BatchPrefillBaseActionObj::RemoveProcessedRequests(\n    const std::vector<BatchPrefillBaseActionObj::PrefillInput>& prefill_inputs,\n    const EngineState& estate, const std::vector<RequestState>& rstates_of_entries) {\n  // - Remove the request from waiting queue if all its request states\n  // are now alive and have no remaining chunked inputs.\n  std::vector<Request> processed_requests;\n  int num_rsentries = prefill_inputs.size();\n  processed_requests.reserve(num_rsentries);\n  std::unordered_set<const RequestNode*> dedup_map;\n  for (int i = 0; i < num_rsentries; ++i) {\n    const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;\n    if (dedup_map.find(rsentry->request.operator->()) != dedup_map.end()) {\n      continue;\n    }\n    dedup_map.insert(rsentry->request.operator->());\n    processed_requests.push_back(rsentry->request);\n\n    bool 
pending_state_exists = false;\n    for (const RequestStateEntry& rsentry_ : rstates_of_entries[i]->entries) {\n      if (rsentry_->status == RequestStateStatus::kPending ||\n          !rsentry_->mstates[0]->inputs.empty()) {\n        pending_state_exists = true;\n        break;\n      }\n    }\n    if (!pending_state_exists &&\n        std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request) !=\n            estate->waiting_queue.end()) {\n      auto it =\n          std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), rsentry->request);\n      if (it != estate->waiting_queue.end()) {\n        estate->waiting_queue.erase(it);\n      }\n    }\n  }\n  return processed_requests;\n}\n\nvoid BatchPrefillBaseActionObj::UpdateRequestStateEntriesWithSampleResults(\n    const std::vector<RequestStateEntry>& rsentries_for_sample,\n    const std::vector<bool>& rsentry_activated, const std::vector<SampleResult>& sample_results) {\n  auto tnow = std::chrono::high_resolution_clock::now();\n  for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {\n    // If the request is a hybrid decode request\n    if (rsentries_for_sample[i]->status == RequestStateStatus::kAlive &&\n        rsentries_for_sample[i]->child_indices.empty() &&\n        rsentries_for_sample[i]->mstates[0]->inputs.empty()) {\n      for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {\n        TVM_FFI_ICHECK(!mstate->require_retokenization_in_next_decode);\n        mstate->CommitToken(sample_results[i]);\n        // live update the output metrics\n        rsentries_for_sample[i]->rstate->metrics.completion_tokens += 1;\n        rsentries_for_sample[i]->rstate->metrics.prefill_end_time_point = tnow;\n      }\n      continue;\n    }\n\n    // Update all model states of the request state entry.\n    for (const RequestModelState& mstate : rsentries_for_sample[i]->mstates) {\n      mstate->CommitToken(sample_results[i]);\n      if 
(!rsentry_activated[i]) {\n        // When the child rsentry is not activated,\n        // add the sampled token as an input of the mstate for prefill.\n        mstate->inputs.push_back(TokenData(std::vector<int64_t>{sample_results[i].GetTokenId()}));\n      }\n    }\n    // prefill has finished\n    if (rsentries_for_sample[i]->mstates[0]->committed_tokens.size() == 1) {\n      TVM_FFI_ICHECK(rsentries_for_sample[i]->rstate != nullptr);\n      rsentries_for_sample[i]->rstate->metrics.prefill_end_time_point = tnow;\n    }\n  }\n}\n\nstd::vector<int32_t> BatchPrefillBaseActionObj::GetConcatPrefillInputData(\n    const RequestModelState& mstate) {\n  std::vector<int32_t> tokens;\n  for (Data data : mstate->inputs) {\n    if (const TokenDataNode* token_data = data.as<TokenDataNode>()) {\n      tokens.reserve(tokens.size() + token_data->GetLength());\n      tokens.insert(tokens.end(), token_data->token_ids.begin(), token_data->token_ids.end());\n    } else {\n      return {};\n    }\n  }\n  return tokens;\n}\n\nvoid BatchPrefillBaseActionObj::PopPrefillInputData(const RequestModelState& mstate,\n                                                    size_t num_tokens) {\n  while (mstate->inputs[0]->GetLength() <= num_tokens) {\n    num_tokens -= mstate->inputs[0]->GetLength();\n    mstate->inputs.erase(mstate->inputs.begin());\n  }\n  if (num_tokens) {\n    const TokenDataNode* token_data = mstate->inputs[0].as<TokenDataNode>();\n    std::vector<int32_t> tokens;\n    tokens.reserve(token_data->GetLength() - num_tokens);\n    tokens.insert(tokens.begin(), token_data->token_ids.begin() + num_tokens,\n                  token_data->token_ids.end());\n    mstate->inputs.erase(mstate->inputs.begin());\n    mstate->inputs.insert(mstate->inputs.begin(), TokenData(tokens));\n  }\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/batch_prefill_base.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/batch_prefill_base.h\n */\n\n#include <tvm/runtime/nvtx.h>\n\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The base action that prefills requests in the `waiting_queue` of\n * the engine state.\n */\nclass BatchPrefillBaseActionObj : public EngineActionObj {\n protected:\n  /*! \\brief A request state entry together with its maximum allowed length for prefill. */\n  struct PrefillInput {\n    RequestStateEntry rsentry;\n    int max_prefill_length = 0;\n    int num_child_to_activate = 0;\n    bool is_decode = false;\n  };\n\n  BatchPrefillBaseActionObj(Array<Model> models, EngineConfig engine_config,\n                            std::vector<tvm::ffi::json::Object> model_configs,\n                            Optional<EventTraceRecorder> trace_recorder);\n\n  /*!\n   * \\brief Find one or multiple request state entries to run prefill.\n   * \\param estate The engine state.\n   * \\return The request entries to prefill, together with their input lengths.\n   */\n  std::vector<PrefillInput> GetRequestStateEntriesToPrefill(EngineState estate);\n\n  /*! \\brief Check if the input requests can be prefilled under the configured limits. 
*/\n  bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length,\n                  int num_required_pages, int num_available_pages, int current_total_seq_len,\n                  int num_running_rsentries, KVStateKind kv_state_kind,\n                  bool sliding_window_enabled);\n\n  /*!\n   * \\brief Chunk the input of the given RequestModelState for prefill\n   * with regard to the provided maximum allowed prefill length.\n   * Return the list of input for prefill and the total prefill length.\n   * The `inputs` field of the given `mstate` will be mutated to exclude\n   * the returned input.\n   * \\param mstate The RequestModelState whose input data is to be chunked.\n   * \\param max_prefill_length The maximum allowed prefill length for the mstate.\n   * \\return The list of input for prefill and the total prefill length.\n   */\n  std::pair<Array<Data>, int> ChunkPrefillInputData(const RequestModelState& mstate,\n                                                    int max_prefill_length);\n\n  /*!\n   * \\brief Update status of request states from pending to alive and collect request state entries\n   * from the prefill input.\n   * \\param prefill_inputs The prefill input.\n   * \\param estate The engine state.\n   * \\param[out] request_ids The array to store the request ids of the request state entries.\n   * \\param[out] rstates_of_entries The vector to store the request state entries.\n   * \\param[out] status_before_prefill The vector to store the status of the request state entries\n   * before prefill.\n   */\n  void UpdateRequestToAlive(const std::vector<PrefillInput>& prefill_inputs,\n                            const EngineState& estate, Array<String>* request_ids,\n                            std::vector<RequestState>* rstates_of_entries,\n                            std::vector<RequestStateStatus>* status_before_prefill);\n\n  /*!\n   * \\brief Remove the request from waiting queue if all its request states are now 
alive and have\n   * no remaining chunked inputs.\n   * \\param prefill_inputs The prefill input.\n   * \\param estate The engine state.\n   * \\param rstates_of_entries The request state entries for each prefill input.\n   * \\return The processed requests.\n   */\n  std::vector<Request> RemoveProcessedRequests(const std::vector<PrefillInput>& prefill_inputs,\n                                               const EngineState& estate,\n                                               const std::vector<RequestState>& rstates_of_entries);\n\n  /*!\n   * \\brief Update the committed tokens of states. If a request is prefilled for the first\n   * time, set the prefill finish time.\n   * \\param rsentries_for_sample The request state entries to sample.\n   * \\param rsentry_activated The activation status of the request state entries.\n   * \\param sample_results The sample results.\n   */\n  void UpdateRequestStateEntriesWithSampleResults(\n      const std::vector<RequestStateEntry>& rsentries_for_sample,\n      const std::vector<bool>& rsentry_activated, const std::vector<SampleResult>& sample_results);\n\n  /*!\n   * \\brief Get the concatenated token ids of the RequestModelState input data. Return an empty\n   * vector if there is untokenized data.\n   * \\param mstate The RequestModelState whose input data is to be concatenated.\n   * \\return The concatenated token ids.\n   */\n  std::vector<int32_t> GetConcatPrefillInputData(const RequestModelState& mstate);\n\n  /*!\n   * \\brief Pop the prefix tokens of the RequestModelState input data array.\n   * \\param mstate The RequestModelState to be popped.\n   * \\param num_tokens The number of prefix tokens to be popped.\n   */\n  void PopPrefillInputData(const RequestModelState& mstate, size_t num_tokens);\n\n  /*!\n   * \\brief Match the request state entry with prefix cache, to skip prefilling common prefix\n   * tokens. 
If the request state entry is not added to KVCache yet, this method will add/fork the\n   * request in the KVCache, depending on the matching result from prefix cache.\n   * \\param estate The engine state.\n   * \\param[in, out] input The prefill input to be matched and updated.\n   * \\return The matched length in prefix cache.\n   */\n  virtual int MatchPrefixCache(EngineState estate, PrefillInput* input) = 0;\n\n  /*! \\brief The models to run prefill in. */\n  Array<Model> models_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n  /*! \\brief The KV state kind. */\n  KVStateKind kv_state_kind_;\n  /*! \\brief The sliding window size of each model. */\n  std::vector<int> sliding_window_sizes_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n};\n\n/*!\n * \\brief A utility function to check whether there is enough spare space in\n * KV cache for the number of required pages and total input length.\n */\nbool HasPrefillSpace(int num_required_pages, bool sliding_window_enabled, int new_batch_size,\n                     int num_available_pages, int current_total_seq_len, int total_input_length,\n                     int max_total_sequence_length);\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/batch_verify.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/batch_verify.cc\n */\n\n#include <tvm/runtime/threading_backend.h>\n\n#include <cmath>\n#include <exception>\n#include <numeric>\n\n#include \"../../support/random.h\"\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs verification for requests in the\n * `running_queue` of engine state. Preempt low-priority requests\n * accordingly when it is impossible to decode all the running requests.\n */\nclass BatchVerifyActionObj : public EngineActionObj {\n public:\n  explicit BatchVerifyActionObj(Array<Model> models, LogitProcessor logit_processor,\n                                Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                EngineConfig engine_config,\n                                Optional<EventTraceRecorder> trace_recorder)\n      : models_(std::move(models)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        model_workspaces_(std::move(model_workspaces)),\n        draft_token_workspace_manager_(std::move(draft_token_workspace_manager)),\n        engine_config_(std::move(engine_config)),\n        trace_recorder_(std::move(trace_recorder)),\n        rng_(RandomGenerator::GetInstance()) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests.\n    if (models_.size() != 2 || estate->running_queue.empty()) {\n      return {};\n    }\n\n    const auto& [rsentries, verify_lengths, total_verify_length] = GetDraftsToVerify(estate);\n    TVM_FFI_ICHECK_EQ(rsentries.size(), verify_lengths.size());\n    if 
(rsentries.empty()) {\n      return {};\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n    int num_rsentries = rsentries.size();\n    Array<String> request_ids =\n        rsentries.Map([](const RequestStateEntry& rstate) { return rstate->request->id; });\n\n    // - Get embedding and run verify.\n    std::vector<int64_t> request_internal_ids;\n    std::vector<int32_t> all_tokens_to_verify;\n    Array<RequestModelState> verify_request_mstates;\n    Array<RequestModelState> draft_request_mstates;\n    Array<GenerationConfig> generation_cfg;\n    Array<GenerationConfig> generation_cfg_for_top_p_norm;\n    std::vector<RandomGenerator*> rngs;\n    std::vector<std::vector<SampleResult>> draft_output_tokens;\n    std::vector<int64_t> token_tree_parent_ptr;\n    std::vector<std::vector<int>> draft_token_indices;\n    token_tree_parent_ptr.reserve(total_verify_length);\n    request_internal_ids.reserve(num_rsentries);\n    all_tokens_to_verify.reserve(total_verify_length);\n    draft_token_indices.reserve(num_rsentries);\n    verify_request_mstates.reserve(num_rsentries);\n    draft_request_mstates.reserve(num_rsentries);\n    rngs.reserve(num_rsentries);\n    generation_cfg.reserve(num_rsentries);\n    generation_cfg_for_top_p_norm.reserve(total_verify_length);\n    draft_output_tokens.reserve(num_rsentries);\n    draft_token_slots_.clear();\n    for (int i = 0; i < num_rsentries; ++i) {\n      RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_];\n      RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_];\n      request_internal_ids.push_back(verify_mstate->internal_id);\n      TVM_FFI_ICHECK(!verify_lengths.empty());\n      TVM_FFI_ICHECK_EQ(verify_lengths[i], draft_mstate->draft_output_tokens.size() + 1);\n      TVM_FFI_ICHECK_EQ(verify_lengths[i], draft_mstate->draft_token_slots.size() + 1);\n      // the last committed token + all the draft tokens.\n      draft_token_slots_.push_back(0);  // placeholder 
for the last committed token\n      all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());\n      token_tree_parent_ptr.push_back(-1);\n      generation_cfg_for_top_p_norm.push_back(rsentries[i]->request->generation_cfg);\n      std::vector<int> cur_draft_token_indices;\n      cur_draft_token_indices.resize(draft_mstate->draft_output_tokens.size() + 1);\n      std::iota(cur_draft_token_indices.begin(), cur_draft_token_indices.end(), -1);\n      for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {\n        all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());\n        draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);\n        token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);\n        generation_cfg_for_top_p_norm.push_back(rsentries[i]->request->generation_cfg);\n      }\n      draft_token_indices.emplace_back(std::move(cur_draft_token_indices));\n      verify_request_mstates.push_back(verify_mstate);\n      draft_request_mstates.push_back(draft_mstate);\n      generation_cfg.push_back(rsentries[i]->request->generation_cfg);\n      rngs.push_back(&rsentries[i]->rng);\n      draft_output_tokens.push_back(draft_mstate->draft_output_tokens);\n    }\n    Tensor draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs(\n        model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_,\n        &model_workspaces_[verify_model_id_].draft_probs);\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"start verify embedding\");\n    ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed(\n        {IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish verify embedding\");\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"start verify\");\n    Tensor logits = models_[verify_model_id_]->BatchVerify(embeddings, request_internal_ids,\n       
                                                    verify_lengths, token_tree_parent_ptr);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish verify\");\n    TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n    TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], total_verify_length);\n\n    // - Update logits.\n    std::vector<int> cum_verify_lengths = {0};\n    cum_verify_lengths.reserve(num_rsentries + 1);\n    for (int i = 0; i < num_rsentries; ++i) {\n      cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths[i]);\n    }\n    logits = logits.CreateView({total_verify_length, logits->shape[2]}, logits->dtype);\n    logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates,\n                                          request_ids, &cum_verify_lengths, &draft_request_mstates,\n                                          &draft_token_indices);\n\n    // - Compute probability distributions.\n    Tensor probs_on_device = logit_processor_->ComputeProbsFromLogits(\n        logits, generation_cfg, request_ids, &cum_verify_lengths);\n\n    // - Commit the prefix cache changes from previous round of action.\n    // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n    estate->prefix_cache->CommitSequenceExtention();\n\n    // Fill range [0, total_verify_length) into `sample_indices`.\n    std::vector<int> sample_indices(total_verify_length);\n    std::iota(sample_indices.begin(), sample_indices.end(), 0);\n    Tensor renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(\n        probs_on_device, sample_indices, request_ids, generation_cfg_for_top_p_norm);\n    auto [sample_results_arr, last_accepted_tree_node_verify_model] =\n        sampler_->BatchVerifyDraftTokensWithProbAfterTopP(\n            renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,\n            draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);\n    
TVM_FFI_ICHECK_EQ(sample_results_arr.size(), num_rsentries);\n\n    // We collect the requests whose drafts are fully accepted.\n    // When a request's draft is fully accepted, there is an extra token proposed\n    // by the draft model but not added into the draft model's KV cache.\n    // In this case, an additional batch decode step is needed for these requests.\n    std::vector<int64_t> fully_accepted_rsentries;\n    std::vector<int64_t> verify_model_seq_internal_ids;\n    std::vector<int64_t> draft_model_seq_internal_ids;\n    fully_accepted_rsentries.reserve(num_rsentries);\n    verify_model_seq_internal_ids.reserve(num_rsentries);\n    draft_model_seq_internal_ids.reserve(num_rsentries);\n\n    // The index of the last accepted tree node in the draft model. This is different from the\n    // last accepted tree node in the verify model because the first round of draft does not\n    // use tree attention.\n    std::vector<int64_t> last_accepted_tree_node_draft_model;\n    last_accepted_tree_node_draft_model.reserve(num_rsentries);\n    for (int i = 0; i < num_rsentries; ++i) {\n      const std::vector<SampleResult>& sample_results = sample_results_arr[i];\n      int accept_length = sample_results.size();\n      for (SampleResult sample_result : sample_results) {\n        rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result);\n        rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result);\n      }\n      // Metrics update\n      // live update the output metrics\n      rsentries[i]->rstate->metrics.completion_tokens += accept_length;\n      estate->metrics.spec_decode.Update(cum_verify_lengths[i + 1] - cum_verify_lengths[i],\n                                         accept_length);\n      if (engine_config_->spec_tree_width == 1) {\n        // The roll back is needed for the chain draft case.\n        int rollback_length =\n            std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);\n        if 
(rollback_length > 0) {\n          // The last accepted token is not yet added into the draft model.\n          // Therefore, the rollback length for the draft model is one less.\n          models_[draft_model_id_]->PopNFromKVCache(\n              rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1);\n        }\n      }\n      // Commit accepted tokens to the \"verify_model\", rollback kv cache\n      // in the \"draft_model\".\n      // NOTE: when the number of small models is more than 1 (in the future),\n      // it is possible to re-compute prefill for the small models.\n      verify_model_seq_internal_ids.push_back(rsentries[i]->mstates[verify_model_id_]->internal_id);\n      draft_model_seq_internal_ids.push_back(rsentries[i]->mstates[draft_model_id_]->internal_id);\n      int last_accepted = last_accepted_tree_node_verify_model[i] -\n                          1;  // minus one to get the index in the draft tokens\n      if (last_accepted >= 0 &&\n          rsentries[i]->mstates[draft_model_id_]->draft_token_first_child_idx[last_accepted] ==\n              -1) {\n        // The last accepted node has no child (it is a leaf), so the draft is fully accepted.\n        last_accepted_tree_node_draft_model.push_back(\n            rsentries[i]->mstates[draft_model_id_]->draft_token_parent_idx[last_accepted]);\n        fully_accepted_rsentries.push_back(i);\n      } else {\n        last_accepted_tree_node_draft_model.push_back(last_accepted);\n      }\n    }\n    models_[verify_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(\n        verify_model_seq_internal_ids,\n        std::vector<int64_t>{last_accepted_tree_node_verify_model.begin(),\n                             last_accepted_tree_node_verify_model.end()});\n    if (engine_config_->spec_tree_width > 1) {\n      models_[draft_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(\n          draft_model_seq_internal_ids, last_accepted_tree_node_draft_model);\n    }\n\n    if 
(!fully_accepted_rsentries.empty()) {\n      // - Run a step of batch decode for requests whose drafts are fully accepted.\n      // When a request's draft is fully accepted, there is an extra token proposed\n      // by the draft model but not added into the draft model's KV cache.\n      // In this case, an additional batch decode step is needed for these requests.\n      std::vector<int> input_tokens;\n      std::vector<int64_t> fully_accepted_request_internal_ids;\n      input_tokens.reserve(fully_accepted_rsentries.size());\n      fully_accepted_request_internal_ids.reserve(fully_accepted_rsentries.size());\n      for (int rsentry_id : fully_accepted_rsentries) {\n        int num_committed_tokens =\n            rsentries[rsentry_id]->mstates[verify_model_id_]->committed_tokens.size();\n        // When a request's draft is fully accepted, an additional new token is sampled.\n        // So the token needed to fill in the draft model is committed_tokens[-2].\n        TVM_FFI_ICHECK_GE(num_committed_tokens, 2);\n        input_tokens.push_back(rsentries[rsentry_id]\n                                   ->mstates[verify_model_id_]\n                                   ->committed_tokens[num_committed_tokens - 2]\n                                   .GetTokenId());\n        fully_accepted_request_internal_ids.push_back(\n            rsentries[rsentry_id]->mstates[draft_model_id_]->internal_id);\n      }\n      // - Compute embeddings.\n      ObjectRef embeddings = models_[draft_model_id_]->TokenEmbed(\n          {IntTuple{input_tokens.begin(), input_tokens.end()}});\n      // - Invoke model decode.\n      Tensor logits =\n          models_[draft_model_id_]->BatchDecode(embeddings, fully_accepted_request_internal_ids);\n      // - We explicitly synchronize to avoid the input tokens getting overwritten in the\n      // next run of BatchDecode.\n      // This is needed because no sampling is done for this round of batch decode.\n      
DeviceAPI::Get(logits->device)->StreamSync(logits->device, nullptr);\n    }\n\n    // clear the draft model state entries\n    for (int i = 0; i < num_rsentries; ++i) {\n      rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_);\n      draft_token_workspace_manager_->FreeSlots(draft_token_slots_);\n      // reset num_tokens_for_next_decode to 1\n      rsentries[i]->mstates[verify_model_id_]->num_tokens_for_next_decode = 1;\n      rsentries[i]->mstates[draft_model_id_]->num_tokens_for_next_decode = 1;\n    }\n\n    auto tend = std::chrono::high_resolution_clock::now();\n    double elapsed_time = static_cast<double>((tend - tstart).count()) / 1e9;\n    estate->metrics.engine_decode_time_sum += elapsed_time;\n    estate->metrics.UpdateVerifyTimeByBatchSize(total_verify_length, elapsed_time);\n\n    return estate->running_queue;\n  }\n\n private:\n  struct DraftRequestStateEntries {\n    /*! \\brief The request state entries to verify. */\n    Array<RequestStateEntry> draft_rsentries;\n    /*! \\brief The length to verify for each request state. */\n    std::vector<int> verify_lengths;\n    /*! \\brief The total draft length. 
*/\n    int total_verify_length;\n  };\n\n  /*!\n   * \\brief Decide whether to run verify for the draft of each request.\n   * \\param estate The engine state.\n   * \\return The drafts to verify, together with their respective\n   * state and input length.\n   */\n  DraftRequestStateEntries GetDraftsToVerify(EngineState estate) {\n    std::vector<int> verify_lengths;\n    int total_verify_length = 0;\n    int total_required_pages = 0;\n    int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages();\n\n    // Preempt the request state entries that cannot fit the large model for verification.\n    std::vector<RequestStateEntry> init_running_rsentries = estate->GetRunningRequestStateEntries();\n    std::vector<int> num_page_requirement;\n    num_page_requirement.reserve(init_running_rsentries.size());\n    std::vector<RequestStateEntry> running_rsentries;\n    running_rsentries.reserve(init_running_rsentries.size());\n    for (const RequestStateEntry& rsentry : init_running_rsentries) {\n      int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size();\n      if (draft_length == 0) {\n        continue;\n      }\n      running_rsentries.push_back(rsentry);\n      int num_require_pages = (draft_length + engine_config_->kv_cache_page_size - 1) /\n                              engine_config_->kv_cache_page_size;\n      verify_lengths.push_back(draft_length + 1);\n      num_page_requirement.push_back(num_require_pages);\n      total_verify_length += draft_length + 1;\n      total_required_pages += num_require_pages;\n    }\n    while (!CanVerify(total_required_pages)) {\n      if (estate->prefix_cache->TryFreeMemory()) continue;\n      RequestStateEntry preempted = PreemptLastRunningRequestStateEntry(\n          estate, models_, draft_token_workspace_manager_, trace_recorder_);\n      if (preempted.same_as(running_rsentries.back())) {\n        total_verify_length -= verify_lengths.back();\n        total_required_pages -= 
num_page_requirement.back();\n        verify_lengths.pop_back();\n        num_page_requirement.pop_back();\n        running_rsentries.pop_back();\n      }\n    }\n    TVM_FFI_ICHECK_LE(total_verify_length,\n                      std::min(static_cast<int64_t>(engine_config_->max_num_sequence),\n                               engine_config_->prefill_chunk_size))\n        << total_verify_length << \" \" << engine_config_->max_num_sequence;\n\n    return {running_rsentries, verify_lengths, total_verify_length};\n  }\n\n  bool CanVerify(int num_required_pages) {\n    int num_available_pages = models_[0]->GetNumAvailablePages();\n    return num_required_pages <= num_available_pages;\n  }\n\n  /*!\n   * \\brief The model to run decode in. When there are multiple\n   * models, the `Step` function of the created action will not take effect.\n   */\n  Array<Model> models_;\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief The model workspaces. */\n  std::vector<ModelWorkspace> model_workspaces_;\n  /*! \\brief The draft token workspace manager. */\n  DraftTokenWorkspaceManager draft_token_workspace_manager_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! \\brief Random number generator. */\n  RandomGenerator& rng_;\n  /*! \\brief The ids of verify/draft models. */\n  const int verify_model_id_ = 0;\n  const int draft_model_id_ = 1;\n  const float eps_ = 1e-5;\n  /*! 
\\brief Temporary buffer to store the slots of the current draft tokens */\n  std::vector<int> draft_token_slots_;\n};\n\nEngineAction EngineAction::BatchVerify(Array<Model> models, LogitProcessor logit_processor,\n                                       Sampler sampler,\n                                       std::vector<ModelWorkspace> model_workspaces,\n                                       DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                       EngineConfig engine_config,\n                                       Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<BatchVerifyActionObj>(\n      std::move(models), std::move(logit_processor), std::move(sampler),\n      std::move(model_workspaces), std::move(draft_token_workspace_manager),\n      std::move(engine_config), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/disagg_prepare_recv.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/disagg_prepare_recv.cc\n */\n\n#include <optional>\n\n#include \"../../support/utils.h\"\n#include \"../sampler/sampler.h\"\n#include \"batch_prefill_base.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs prefill preparation in the disaggregation system.\n * It picks a new request, reserves its KV data locations, and returns the\n * KV data locations and the matched prefix length in the prefix cache.\n */\nclass DisaggPrepareReceiveActionObj : public BatchPrefillBaseActionObj {\n public:\n  explicit DisaggPrepareReceiveActionObj(Array<Model> models, EngineConfig engine_config,\n                                         std::vector<tvm::ffi::json::Object> model_configs,\n                                         Optional<EventTraceRecorder> trace_recorder,\n                                         FRequestStreamCallback request_stream_callback)\n      : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config),\n                                  std::move(model_configs), std::move(trace_recorder)),\n        request_stream_callback_(std::move(request_stream_callback)) {\n    TVM_FFI_ICHECK(kv_state_kind_ == KVStateKind::kKVCache)\n        << \"Only PagedKVCache supports prefill preparation and KV migration\";\n  }\n\n  Array<Request> Step(EngineState estate) final {\n    std::vector<Request> processed_requests;\n\n    // - Find the requests in `waiting_queue` that can prefill in this step.\n    std::optional<PrefillInput> prefill_input_opt;\n    while (true) {\n      prefill_input_opt = GetRequestStateEntriesToPrefill(estate);\n      if (!prefill_input_opt.has_value()) {\n        break;\n      }\n      PrefillInput prefill_input = prefill_input_opt.value();\n      int prefix_matched_length = 0;\n      Request request = prefill_input.rsentry->request;\n      processed_requests.push_back(request);\n      int 
total_input_length = 0;\n      for (const Data& data : request->inputs) {\n        total_input_length += data->GetLength();\n      }\n\n      {\n        NVTXScopedRange nvtx_scope(\"DisaggPrepareReceive matching prefix\");\n        prefix_matched_length = MatchPrefixCache(estate, &prefill_input);\n      }\n\n      auto tstart = std::chrono::high_resolution_clock::now();\n\n      // - Update status of request states from pending to alive.\n      Array<String> request_ids;\n      std::vector<RequestState> rstates_of_entries;\n      std::vector<RequestStateStatus> status_before_prefill;\n      UpdateRequestToAlive({prefill_input}, estate, &request_ids, &rstates_of_entries,\n                           &status_before_prefill);\n      // \"UpdateRequestToAlive\" may add the request to the engine's running request queue.\n      // We erase it since it's pending for the prefill instance to send the KV data over.\n      if (!estate->running_queue.empty() && estate->running_queue.back().same_as(request)) {\n        estate->running_queue.pop_back();\n      }\n\n      // - Add the sequence to each model.\n      int prefill_length = -1;\n      Tensor logits_for_sample{nullptr};\n      std::vector<IntTuple> kv_append_metadata;\n      kv_append_metadata.reserve(models_.size());\n      for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n        const RequestStateEntry& rsentry = prefill_input.rsentry;\n        RequestModelState mstate = rsentry->mstates[model_id];\n        Array<Data> input_data = mstate->inputs;\n        mstate->inputs.clear();\n        int input_length = prefill_input.max_prefill_length;\n        if (prefill_length == -1) {\n          prefill_length = input_length;\n        } else {\n          TVM_FFI_ICHECK_EQ(prefill_length, input_length);\n        }\n        mstate->num_prefilled_tokens += input_length;\n\n        TVM_FFI_ICHECK(mstate->draft_output_tokens.empty());\n        TVM_FFI_ICHECK(mstate->draft_token_slots.empty());\n   
     if (status_before_prefill[0] == RequestStateStatus::kPending &&\n            !estate->prefix_cache->HasSequence(mstate->internal_id)) {\n          // Add the sequence to the model, or fork the sequence from its parent.\n          // If the sequence is already in the prefix cache, it has also been added/forked in the\n          // KVCache.\n          if (rsentry->parent_idx == -1) {\n            models_[model_id]->AddNewSequence(mstate->internal_id);\n          } else {\n            models_[model_id]->ForkSequence(\n                rstates_of_entries[0]->entries[rsentry->parent_idx]->mstates[model_id]->internal_id,\n                mstate->internal_id);\n          }\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id);\n          }\n        }\n\n        // Record the prefilled inputs for the prefix cache update.\n        for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {\n          if (!model_id && !prefill_input.is_decode) {\n            mstate->prefilled_inputs.push_back(input_data[j]);\n          }\n        }\n\n        int64_t request_internal_id = mstate->internal_id;\n        RECORD_EVENT(trace_recorder_, request_ids, \"start prefill\");\n        IntTuple compressed_kv_append_metadata = {0};\n        if (prefill_length > 0) {\n          compressed_kv_append_metadata =\n              models_[model_id]->DisaggPrepareKVRecv(request_internal_id, prefill_length);\n        }\n        kv_append_metadata.push_back(compressed_kv_append_metadata);\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish prefill\");\n      }\n\n      // - Commit the prefix cache changes from previous round of action.\n      // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n      estate->prefix_cache->CommitSequenceExtention();\n\n      auto tend = 
std::chrono::high_resolution_clock::now();\n\n      // - Remove the request from the waiting queue.\n      auto it_request =\n          std::find(estate->waiting_queue.begin(), estate->waiting_queue.end(), request);\n      TVM_FFI_ICHECK(it_request != estate->waiting_queue.end());\n      estate->waiting_queue.erase(it_request);\n\n      {\n        NVTXScopedRange nvtx_scope(\"Call request stream callback\");\n        tvm::ffi::json::Object response_body;\n        response_body.Set(\"prompt_length\", static_cast<int64_t>(total_input_length));\n        response_body.Set(\"prefix_matched_length\", static_cast<int64_t>(prefix_matched_length));\n        // We further flatten the metadata array of all models into a single array.\n        tvm::ffi::json::Array kv_append_metadata_arr;\n        for (const IntTuple& compressed_kv_append_metadata : kv_append_metadata) {\n          for (int64_t value : compressed_kv_append_metadata) {\n            kv_append_metadata_arr.push_back(value);\n          }\n          TVM_FFI_ICHECK(!compressed_kv_append_metadata.empty());\n          int num_segments = compressed_kv_append_metadata[0];\n          TVM_FFI_ICHECK_EQ(compressed_kv_append_metadata.size(), num_segments * 2 + 1);\n          int transmission_length = 0;\n          for (int i = 0; i < num_segments; ++i) {\n            transmission_length += compressed_kv_append_metadata[i * 2 + 2];\n          }\n          TVM_FFI_ICHECK_EQ(transmission_length, prefill_length);\n        }\n\n        response_body.Set(\n            \"kv_append_metadata\",\n            Base64Encode(std::string(tvm::ffi::json::Stringify(kv_append_metadata_arr))));\n\n        tvm::ffi::json::Object usage;\n        usage.Set(\"prompt_tokens\", static_cast<int64_t>(0));\n        usage.Set(\"completion_tokens\", static_cast<int64_t>(0));\n        usage.Set(\"total_tokens\", static_cast<int64_t>(0));\n        usage.Set(\"extra\", response_body);\n        RequestStreamOutput stream_output =\n            
RequestStreamOutput::Usage(request->id, std::string(tvm::ffi::json::Stringify(usage)));\n        // - Invoke the stream callback function once for all collected requests.\n        request_stream_callback_(Array<RequestStreamOutput>{stream_output});\n      }\n    }\n\n    for (const Request& request : processed_requests) {\n      TVM_FFI_ICHECK(std::find(estate->running_queue.begin(), estate->running_queue.end(),\n                               request) == estate->running_queue.end());\n    }\n    return {processed_requests};\n  }\n\n private:\n  // Mimicked from BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill\n  std::optional<PrefillInput> GetRequestStateEntriesToPrefill(EngineState estate) {\n    const std::vector<RequestStateEntry>* running_rsentries;\n    {\n      NVTXScopedRange nvtx_scope(\"BatchDecode getting requests\");\n      running_rsentries = &estate->GetRunningRequestStateEntries();\n      if (!(running_rsentries->size() <= models_[0]->GetNumAvailablePages())) {\n        // Even the decode cannot be performed.\n        // As a result, directly return without doing prefill.\n        return {};\n      }\n    }\n    int num_running_rsentries = static_cast<int>(running_rsentries->size());\n\n    Request request{nullptr};\n    for (const Request& request_candidate : estate->waiting_queue) {\n      if (request_candidate->generation_cfg->debug_config.disagg_config.kind ==\n          DisaggRequestKind::kPrepareReceive) {\n        request = request_candidate;\n        break;\n      }\n    }\n    if (!request.defined()) {\n      // No request to prepare for prefill.\n      return {};\n    }\n    TVM_FFI_ICHECK_EQ(\n        request->generation_cfg->debug_config.disagg_config.kv_window_begin.value_or(0), 0);\n\n    std::vector<PrefillInput> prefill_input_for_all_models;\n    prefill_input_for_all_models.reserve(models_.size());\n\n    // We first collect the inputs that can be prefilled for each model.\n    // The inputs for each model are expected to 
be exactly the same.\n    for (int i = 0; i < static_cast<int>(models_.size()); ++i) {\n      NVTXScopedRange nvtx_scope(\"Process request \" + request->id);\n\n      PrefillInput prefill_input;\n      // - Try to prefill pending requests.\n      int num_available_pages = models_[i]->GetNumAvailablePages();\n      int current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();\n\n      RequestState rstate = estate->GetRequestState(request);\n      bool prefill_stops = false;\n      for (int j = 1; j < static_cast<int>(rstate->entries.size()); ++j) {\n        TVM_FFI_ICHECK(rstate->entries[j]->mstates[i]->inputs.empty())\n            << \"Re-prefill of preempted requests is not supported by prefill preparation.\";\n      }\n      const RequestStateEntry& rsentry = rstate->entries[0];\n      TVM_FFI_ICHECK(!rsentry->mstates[i]->inputs.empty())\n          << \"The request entry must have pending inputs.\";\n\n      // TODO: handle the case where the input length is 1.\n\n      int input_length = rsentry->mstates[i]->GetInputLength();\n      // Update the input length with the requested KV window, where \"[begin:end]\"\n      // means the KV range to prefill on a prefill instance.\n      int kv_window_begin =\n          request->generation_cfg->debug_config.disagg_config.kv_window_begin.value_or(0);\n      int kv_window_end =\n          request->generation_cfg->debug_config.disagg_config.kv_window_end.value_or(input_length);\n      TVM_FFI_ICHECK_EQ(kv_window_begin, 0);\n      if (kv_window_end < 0) {\n        kv_window_end = input_length + kv_window_end;\n      }\n      TVM_FFI_ICHECK_GE(kv_window_end, 0);\n      TVM_FFI_ICHECK_LT(kv_window_end, input_length)\n          << \"Prefilling the full input on the remote machine is not supported.\";\n      int orig_input_length = input_length;\n      input_length = kv_window_end;\n\n      int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /\n                              
engine_config_->kv_cache_page_size;\n      bool sliding_window_enabled = sliding_window_sizes_[i] != -1;\n      int num_required_pages_under_sliding_window = std::numeric_limits<int>::max();\n      if (sliding_window_enabled) {\n        // Sliding window for model i is enabled.\n        int max_single_request_page_requirement =\n            1 + (sliding_window_sizes_[i] + engine_config_->kv_cache_page_size - 1) /\n                    engine_config_->kv_cache_page_size;\n        int num_total_prefilled_tokens = rsentry->mstates[i]->num_prefilled_tokens;\n        int num_pages_in_use = (std::min(num_total_prefilled_tokens, sliding_window_sizes_[i]) +\n                                engine_config_->kv_cache_page_size - 1) /\n                               engine_config_->kv_cache_page_size;\n        num_required_pages_under_sliding_window =\n            max_single_request_page_requirement - num_pages_in_use;\n        num_require_pages = std::min(num_require_pages, num_required_pages_under_sliding_window);\n        TVM_FFI_ICHECK_GE(num_require_pages, 0);\n      }\n\n      // Check if the entire request state entry can fit for prefill.\n      bool can_prefill = false;\n      {\n        NVTXScopedRange nvtx_scope(\"Attempt\");\n        for (int num_child_to_activate = rsentry->child_indices.size(); num_child_to_activate >= 0;\n             --num_child_to_activate) {\n          while (!HasPrefillSpace(num_require_pages, sliding_window_enabled, num_running_rsentries,\n                                  num_available_pages, current_total_seq_len, input_length,\n                                  engine_config_->max_total_sequence_length)) {\n            if (!estate->prefix_cache->TryFreeMemory()) break;\n            // Update number of available pages after memory free.\n            num_available_pages = models_[i]->GetNumAvailablePages();\n            current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();\n          }\n          if (CanPrefill(estate, 1 + 
num_child_to_activate, input_length, num_require_pages,\n                         num_available_pages, current_total_seq_len, num_running_rsentries,\n                         kv_state_kind_, sliding_window_enabled)) {\n            prefill_input = {rsentry, input_length, num_child_to_activate, /*is_decode=*/false};\n            can_prefill = true;\n            break;\n          }\n        }\n      }\n      if (!can_prefill) {\n        return std::nullopt;\n      }\n      rsentry->mstates[i]->inputs =\n          SplitData(rsentry->mstates[i]->inputs, orig_input_length, kv_window_end).first;\n      prefill_input_for_all_models.push_back(prefill_input);\n    }\n\n    // Prefill inputs of all models should be the same.\n    TVM_FFI_ICHECK(!prefill_input_for_all_models.empty());\n    PrefillInput prefill_input = prefill_input_for_all_models[0];\n    {\n      NVTXScopedRange nvtx_scope(\"reduction\");\n      for (int i = 1; i < static_cast<int>(prefill_input_for_all_models.size()); ++i) {\n        TVM_FFI_ICHECK(prefill_input_for_all_models[i].rsentry.same_as(prefill_input.rsentry));\n        TVM_FFI_ICHECK_EQ(prefill_input_for_all_models[i].max_prefill_length,\n                          prefill_input.max_prefill_length);\n        TVM_FFI_ICHECK_EQ(prefill_input_for_all_models[i].num_child_to_activate,\n                          prefill_input.num_child_to_activate);\n      }\n    }\n\n    return prefill_input;\n  }\n\n  // Mimicked from BatchPrefillBaseActionObj::CanPrefill\n  bool CanPrefill(EngineState estate, int num_prefill_rsentries, int total_input_length,\n                  int num_required_pages, int num_available_pages, int current_total_seq_len,\n                  int num_running_rsentries, KVStateKind kv_state_kind,\n                  bool sliding_window_enabled) {\n    // No exceeding of the maximum allowed requests that can\n    // run simultaneously.\n    int spec_factor = engine_config_->speculative_mode != SpeculativeMode::kDisable\n                        
  ? (estate->spec_draft_length + 1)\n                          : 1;\n    if ((num_running_rsentries + num_prefill_rsentries) * spec_factor >\n        std::min(static_cast<int64_t>(engine_config_->max_num_sequence),\n                 engine_config_->prefill_chunk_size)) {\n      return false;\n    }\n\n    // NOTE: The conditions are heuristic and can be revised.\n    // Cond 1: at least one decode can be performed after prefill.\n    // Cond 2: number of total tokens after \"x\" times of decode does not\n    // exceed the limit, where \"x\" is a watermark number that can\n    // be configured and adjusted in the future.\n    if (num_required_pages + 400 > num_available_pages) {\n      return false;\n    }\n    return HasPrefillSpace(num_required_pages, sliding_window_enabled,\n                           (num_running_rsentries + num_prefill_rsentries), num_available_pages,\n                           current_total_seq_len, total_input_length,\n                           engine_config_->max_total_sequence_length);\n  }\n\n  // Mimicked from NewRequestPrefillActionObj::MatchPrefixCache\n  int MatchPrefixCache(EngineState estate, PrefillInput* input) final {\n    RequestStateEntry rsentry = input->rsentry;\n    if (estate->prefix_cache->Mode() == PrefixCacheMode::kDisable) {\n      return 0;\n    }\n    if (rsentry->parent_idx == -1 && rsentry->status == RequestStateStatus::kPending &&\n        !estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {\n      std::vector<int32_t> tokens = GetConcatPrefillInputData(rsentry->mstates[0]);\n      if (tokens.empty()) {\n        // If the RequestStateEntry has empty input data or is not fully tokenized, do nothing\n        // and return.\n        return 0;\n      }\n      PrefixCacheMatchedResult result = estate->prefix_cache->InsertSequence(\n          rsentry->mstates[0]->internal_id, tokens, models_[0]->GetSlidingWindowSize(),\n          models_[0]->GetAttentionSinkSize());\n\n      if (result.prefilled_offset == 
0) {\n        // Add new sequence\n        TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n        for (Model model : models_) {\n          model->AddNewSequence(rsentry->mstates[0]->internal_id);\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            model->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n          }\n        }\n      } else {\n        if (result.forked_seq_id != -1) {\n          TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n          TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n          // Fork from active sequence\n          for (Model model : models_) {\n            model->ForkSequence(result.forked_seq_id, rsentry->mstates[0]->internal_id,\n                                result.prefilled_offset);\n            // Enable sliding window for the sequence if it is not a parent.\n            if (rsentry->child_indices.empty()) {\n              model->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n            }\n          }\n        } else {\n          // Reuse recycling sequence\n          TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n          estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id);\n          for (int i = 0; i < rsentry->mstates.size(); ++i) {\n            rsentry->mstates[i]->internal_id = result.reused_seq_id;\n          }\n          if (result.reused_seq_pop_last_tokens > 0) {\n            for (Model model : models_) {\n              model->PopNFromKVCache(rsentry->mstates[0]->internal_id,\n                                     result.reused_seq_pop_last_tokens);\n            }\n          }\n        }\n      }\n      // Pop matched prefix\n      if (result.prefilled_offset) {\n        for (int i = 0; i < rsentry->mstates.size(); ++i) {\n          
PopPrefillInputData(rsentry->mstates[i], result.prefilled_offset);\n        }\n      }\n      // Update max prefill length\n      input->max_prefill_length =\n          std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());\n      return result.prefilled_offset;\n    }\n    return 0;\n  }\n\n  /*!\n   * \\brief The stream callback function that passes back the KV cache metadata\n   * and the matched prefix length in the prefix cache.\n   */\n  FRequestStreamCallback request_stream_callback_;\n};\n\nEngineAction EngineAction::DisaggPrepareReceive(Array<Model> models, EngineConfig engine_config,\n                                                std::vector<tvm::ffi::json::Object> model_configs,\n                                                Optional<EventTraceRecorder> trace_recorder,\n                                                FRequestStreamCallback request_stream_callback) {\n  return EngineAction(tvm::ffi::make_object<DisaggPrepareReceiveActionObj>(\n      std::move(models), std::move(engine_config), std::move(model_configs),\n      std::move(trace_recorder), std::move(request_stream_callback)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/disagg_remote_send.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/disagg_remote_send.cc\n */\n\n#include \"../sampler/sampler.h\"\n#include \"batch_prefill_base.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that prefills requests in the `waiting_queue` of\n * the engine state.\n * After prefill, this action also sends the computed KV data to\n * remote instances.\n */\nclass DisaggRemoteSendActionObj : public BatchPrefillBaseActionObj {\n public:\n  explicit DisaggRemoteSendActionObj(Array<Model> models,\n                                     std::vector<ModelWorkspace> model_workspaces,\n                                     EngineConfig engine_config,\n                                     std::vector<tvm::ffi::json::Object> model_configs,\n                                     Optional<EventTraceRecorder> trace_recorder,\n                                     FRequestStreamCallback request_stream_callback, Device device)\n      : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config),\n                                  std::move(model_configs), std::move(trace_recorder)),\n        model_workspaces_(std::move(model_workspaces)),\n        request_stream_callback_(std::move(request_stream_callback)),\n        device_(device) {\n    if (device.device_type == DLDeviceType::kDLCUDA ||\n        device.device_type == DLDeviceType::kDLROCM) {\n      // The compute stream is the default stream.\n      compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);\n    }\n  }\n\n  // Mimicked from NewRequestPrefillActionObj::Step\n  Array<Request> Step(EngineState estate) final {\n    // - Find the requests in `waiting_queue` that can prefill in this step.\n    std::vector<PrefillInput> prefill_inputs;\n    {\n      NVTXScopedRange nvtx_scope(\"DisaggRemoteSend getting requests\");\n      prefill_inputs = GetRequestStateEntriesToPrefill(estate);\n      if 
(prefill_inputs.empty()) {\n        return {};\n      }\n    }\n\n    int num_rsentries = prefill_inputs.size();\n    {\n      NVTXScopedRange nvtx_scope(\"DisaggRemoteSend matching prefix\");\n      for (int i = 0; i < num_rsentries; ++i) {\n        MatchPrefixCache(estate, &prefill_inputs[i]);\n      }\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n\n    // - Update status of request states from pending to alive.\n    Array<String> request_ids;\n    std::vector<RequestState> rstates_of_entries;\n    std::vector<RequestStateStatus> status_before_prefill;\n    UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries,\n                         &status_before_prefill);\n\n    // - Get embedding and run prefill for each model.\n    // NOTE: we don't keep the logits as we don't run sampling in this action by design.\n    std::vector<int> prefill_lengths;\n    prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1);\n    for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n      std::vector<int64_t> request_internal_ids;\n      request_internal_ids.reserve(num_rsentries);\n      ObjectRef embeddings = model_workspaces_[model_id].embeddings;\n      int cum_prefill_length = 0;\n      bool single_input =\n          num_rsentries == 1 && prefill_inputs[0].rsentry->mstates[model_id]->inputs.size() == 1;\n      std::vector<int64_t> cached_token_data;\n      for (int i = 0; i < num_rsentries; ++i) {\n        const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;\n        RequestModelState mstate = rsentry->mstates[model_id];\n        auto [input_data, input_length] =\n            ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length);\n        if (prefill_lengths[i] == -1) {\n          prefill_lengths[i] = input_length;\n        } else {\n          TVM_FFI_ICHECK_EQ(prefill_lengths[i], input_length);\n        }\n        mstate->num_prefilled_tokens += input_length;\n\n        
TVM_FFI_ICHECK(mstate->draft_output_tokens.empty());\n        TVM_FFI_ICHECK(mstate->draft_token_slots.empty());\n        if (status_before_prefill[i] == RequestStateStatus::kPending &&\n            !estate->prefix_cache->HasSequence(mstate->internal_id)) {\n          // Add the sequence to the model.\n          // If the sequence is already in prefix cache, it has also been added/forked in the\n          // KVCache.\n          TVM_FFI_ICHECK_EQ(rsentry->parent_idx, -1);\n          models_[model_id]->AddNewSequence(mstate->internal_id);\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id);\n          }\n          DisaggConfig disagg_config = mstate->request->generation_cfg->debug_config.disagg_config;\n          TVM_FFI_ICHECK(disagg_config.dst_group_offset.has_value());\n          models_[model_id]->DisaggMarkKVSend(\n              mstate->internal_id, disagg_config.kv_window_begin.value_or(0),\n              disagg_config.kv_append_metadata[model_id], disagg_config.dst_group_offset.value());\n        }\n        request_internal_ids.push_back(mstate->internal_id);\n        RECORD_EVENT(trace_recorder_, rsentry->request->id, \"start embedding\");\n        for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {\n          if (!model_id && !prefill_inputs[i].is_decode) {\n            mstate->prefilled_inputs.push_back(input_data[j]);\n          }\n          if (const auto* token_data = input_data[j].as<TokenDataNode>()) {\n            cached_token_data.insert(cached_token_data.end(), token_data->token_ids.begin(),\n                                     token_data->token_ids.end());\n          } else {\n            if (!cached_token_data.empty()) {\n              embeddings = TokenData(cached_token_data)\n                               ->GetEmbedding(models_[model_id],\n                                            
  /*dst=*/!single_input ? &embeddings : nullptr,\n                                              /*offset=*/cum_prefill_length);\n              cum_prefill_length += cached_token_data.size();\n              cached_token_data.clear();\n            }\n            embeddings = input_data[j]->GetEmbedding(models_[model_id],\n                                                     /*dst=*/!single_input ? &embeddings : nullptr,\n                                                     /*offset=*/cum_prefill_length);\n            cum_prefill_length += input_data[j]->GetLength();\n          }\n        }\n        RECORD_EVENT(trace_recorder_, rsentry->request->id, \"finish embedding\");\n      }\n      if (!cached_token_data.empty()) {\n        embeddings = TokenData(cached_token_data)\n                         ->GetEmbedding(models_[model_id],\n                                        /*dst=*/!single_input ? &embeddings : nullptr,\n                                        /*offset=*/cum_prefill_length);\n        cum_prefill_length += cached_token_data.size();\n        cached_token_data.clear();\n      }\n\n      RECORD_EVENT(trace_recorder_, request_ids, \"start prefill\");\n      Tensor logits =\n          models_[model_id]->BatchPrefill(embeddings, request_internal_ids, prefill_lengths);\n      RECORD_EVENT(trace_recorder_, request_ids, \"finish prefill\");\n      TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n      TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n      TVM_FFI_ICHECK_EQ(logits->shape[1], num_rsentries);\n    }\n\n    // - Commit the prefix cache changes from previous round of action.\n    // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n    estate->prefix_cache->CommitSequenceExtention();\n\n    // - We run synchronize to make sure that the prefill is finished.\n    // We need explicit synchronization because we don't do sampling in this action.\n    DeviceAPI::Get(device_)->StreamSync(device_, compute_stream_);\n\n    auto tend = 
std::chrono::high_resolution_clock::now();\n    estate->metrics.engine_prefill_time_sum += static_cast<double>((tend - tstart).count()) / 1e9;\n\n    std::vector<Request> processed_requests =\n        RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries);\n    estate->running_rsentries_changed = true;\n    return processed_requests;\n  }\n\n private:\n  // Mimicked from BatchPrefillBaseActionObj::GetRequestStateEntriesToPrefill\n  std::vector<PrefillInput> GetRequestStateEntriesToPrefill(EngineState estate) {\n    // Preempt request state entries when decode cannot apply.\n    const std::vector<RequestStateEntry>* running_rsentries;\n    {\n      NVTXScopedRange nvtx_scope(\"BatchDecode getting requests\");\n      running_rsentries = &estate->GetRunningRequestStateEntries();\n      if (!(running_rsentries->size() <= models_[0]->GetNumAvailablePages())) {\n        // Even the decode cannot be performed.\n        // As a result, directly return without doing prefill.\n        return {};\n      }\n    }\n\n    // Explicitly filter the waiting queue to only keep the requests\n    // with disaggregation request kind \"kRemoteSend\".\n    std::vector<Request> waiting_queue;\n    waiting_queue.reserve(estate->waiting_queue.size());\n    for (Request request : estate->waiting_queue) {\n      if (request->generation_cfg->debug_config.disagg_config.kind ==\n          DisaggRequestKind::kRemoteSend) {\n        waiting_queue.push_back(request);\n      }\n    }\n    if (waiting_queue.empty()) {\n      // No request to prefill.\n      return {};\n    }\n\n    std::vector<std::vector<PrefillInput>> prefill_inputs_for_all_models;\n    prefill_inputs_for_all_models.reserve(models_.size());\n\n    int num_running_rsentries = static_cast<int>(running_rsentries->size());\n    // We first collect the inputs that can be prefilled for each model.\n    // Then we make a reduction to return the maximum common inputs.\n    for (int i = 0; i < static_cast<int>(models_.size()); 
++i) {\n      std::vector<PrefillInput> prefill_inputs;\n      // - Try to prefill pending requests.\n      int total_input_length = 0;\n      int total_required_pages = 0;\n      int num_available_pages;\n      int current_total_seq_len;\n      {\n        NVTXScopedRange nvtx_scope(\"KV cache GetNumAvailablePages\");\n        num_available_pages = models_[i]->GetNumAvailablePages();\n      }\n      {\n        NVTXScopedRange nvtx_scope(\"KV cache GetCurrentTotalSequenceLength\");\n        current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();\n      }\n\n      int num_prefill_rsentries = 0;\n      for (const Request& request : waiting_queue) {\n        NVTXScopedRange nvtx_scope(\"Process request \" + request->id);\n        RequestState rstate = estate->GetRequestState(request);\n        TVM_FFI_ICHECK_EQ(rstate->entries.size(), 1) << \"n > 1 is not supported.\";\n        const RequestStateEntry& rsentry = rstate->entries[0];\n        TVM_FFI_ICHECK(!rsentry->mstates[i]->inputs.empty())\n            << \"The request entry must have pending inputs.\";\n\n        int input_length = rsentry->mstates[i]->GetInputLength();\n        int num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /\n                                engine_config_->kv_cache_page_size;\n        bool sliding_window_enabled = sliding_window_sizes_[i] != -1;\n        int num_required_pages_under_sliding_window = std::numeric_limits<int>::max();\n        if (sliding_window_enabled) {\n          // Sliding window for model i is enabled.\n          int max_single_request_page_requirement =\n              1 + (sliding_window_sizes_[i] + engine_config_->kv_cache_page_size - 1) /\n                      engine_config_->kv_cache_page_size;\n          int num_total_prefilled_tokens = rsentry->mstates[i]->num_prefilled_tokens;\n          int parent_ptr = rsentry->parent_idx;\n          while (parent_ptr != -1) {\n            num_total_prefilled_tokens +=\n                
rstate->entries[parent_ptr]->mstates[i]->num_prefilled_tokens;\n            parent_ptr = rstate->entries[parent_ptr]->parent_idx;\n          }\n\n          int num_pages_in_use = (std::min(num_total_prefilled_tokens, sliding_window_sizes_[i]) +\n                                  engine_config_->kv_cache_page_size - 1) /\n                                 engine_config_->kv_cache_page_size;\n          num_required_pages_under_sliding_window =\n              max_single_request_page_requirement - num_pages_in_use;\n          num_require_pages = std::min(num_require_pages, num_required_pages_under_sliding_window);\n          TVM_FFI_ICHECK_GE(num_require_pages, 0);\n        }\n\n        total_input_length += input_length;\n        total_required_pages += num_require_pages;\n        // - Attempt 1. Check if the entire request state entry can fit for prefill.\n        bool can_prefill = false;\n        {\n          NVTXScopedRange nvtx_scope(\"Attempt 1\");\n          for (int num_child_to_activate = rsentry->child_indices.size();\n               num_child_to_activate >= 0; --num_child_to_activate) {\n            while (!HasPrefillSpace(total_required_pages, sliding_window_enabled,\n                                    (num_running_rsentries + num_prefill_rsentries),\n                                    num_available_pages, current_total_seq_len, total_input_length,\n                                    engine_config_->max_total_sequence_length)) {\n              if (!estate->prefix_cache->TryFreeMemory()) break;\n              // Update number of available pages after memory free.\n              num_available_pages = models_[i]->GetNumAvailablePages();\n              current_total_seq_len = models_[i]->GetCurrentTotalSequenceLength();\n            }\n            if (CanPrefill(estate, num_prefill_rsentries + 1 + num_child_to_activate,\n                           total_input_length, total_required_pages, num_available_pages,\n                           
current_total_seq_len, num_running_rsentries, kv_state_kind_,\n                           sliding_window_enabled)) {\n              prefill_inputs.push_back(\n                  {rsentry, input_length, num_child_to_activate, /*is_decode=*/false});\n              num_prefill_rsentries += 1 + num_child_to_activate;\n              can_prefill = true;\n              break;\n            }\n          }\n        }\n        if (can_prefill) {\n          continue;\n        }\n        total_input_length -= input_length;\n        total_required_pages -= num_require_pages;\n\n        // - Attempt 2. Check if the request state entry can partially fit by input chunking.\n        TVM_FFI_ICHECK_LE(total_input_length, engine_config_->prefill_chunk_size);\n        if (engine_config_->prefill_chunk_size - total_input_length >= input_length ||\n            engine_config_->prefill_chunk_size == total_input_length) {\n          // 1. If the input length can fit the remaining prefill chunk size,\n          // it means the failure of attempt 1 is not because of the input\n          // length being too long, and thus chunking does not help.\n          // 2. 
If the total input length already reaches the prefill chunk size,\n          // the current request state entry will not be able to be processed.\n          // So we can safely return in either case.\n          break;\n        }\n        input_length = engine_config_->prefill_chunk_size - total_input_length;\n        num_require_pages = (input_length + engine_config_->kv_cache_page_size - 1) /\n                            engine_config_->kv_cache_page_size;\n        if (sliding_window_enabled) {\n          // Sliding window for model i is enabled.\n          num_require_pages = std::min(num_require_pages, num_required_pages_under_sliding_window);\n          TVM_FFI_ICHECK_GE(num_require_pages, 0);\n        }\n\n        {\n          NVTXScopedRange nvtx_scope(\"Attempt 2\");\n          total_input_length += input_length;\n          total_required_pages += num_require_pages;\n          if (CanPrefill(estate, num_prefill_rsentries + 1, total_input_length,\n                         total_required_pages, num_available_pages, current_total_seq_len,\n                         num_running_rsentries, kv_state_kind_, sliding_window_enabled)) {\n            prefill_inputs.push_back({rsentry, input_length, 0, /*is_decode=*/false});\n          }\n        }\n\n        // - Prefill stops here.\n        break;\n      }\n      prefill_inputs_for_all_models.push_back(prefill_inputs);\n    }\n\n    // Reduce over the prefill inputs of all models.\n    TVM_FFI_ICHECK(!prefill_inputs_for_all_models.empty());\n    int num_prefill_inputs = prefill_inputs_for_all_models[0].size();\n    for (int i = 1; i < static_cast<int>(prefill_inputs_for_all_models.size()); ++i) {\n      num_prefill_inputs =\n          std::min(num_prefill_inputs, static_cast<int>(prefill_inputs_for_all_models[i].size()));\n    }\n\n    if (num_prefill_inputs == 0) {\n      return {};\n    }\n\n    // Add the decode requests to the prefill inputs if prefill mode is hybrid.\n    std::vector<PrefillInput> 
prefill_inputs(prefill_inputs_for_all_models[0].begin(),\n                                             prefill_inputs_for_all_models[0].end());\n    {\n      NVTXScopedRange nvtx_scope(\"reduction\");\n      for (int i = 1; i < static_cast<int>(prefill_inputs_for_all_models.size()); ++i) {\n        // Prefill input lengths except the last one are supposed to be the same for all models.\n        for (int j = 0; j < num_prefill_inputs - 1; ++j) {\n          TVM_FFI_ICHECK(\n              prefill_inputs_for_all_models[i][j].rsentry.same_as(prefill_inputs[j].rsentry));\n          TVM_FFI_ICHECK_EQ(prefill_inputs_for_all_models[i][j].max_prefill_length,\n                            prefill_inputs[j].max_prefill_length);\n          prefill_inputs[j].num_child_to_activate =\n              std::min(prefill_inputs[j].num_child_to_activate,\n                       prefill_inputs_for_all_models[i][j].num_child_to_activate);\n        }\n        // The input length of the last input is the minimum among all models.\n        TVM_FFI_ICHECK(prefill_inputs_for_all_models[i][num_prefill_inputs - 1].rsentry.same_as(\n            prefill_inputs[num_prefill_inputs - 1].rsentry));\n        prefill_inputs[num_prefill_inputs - 1].max_prefill_length =\n            std::min(prefill_inputs[num_prefill_inputs - 1].max_prefill_length,\n                     prefill_inputs_for_all_models[i][num_prefill_inputs - 1].max_prefill_length);\n        prefill_inputs[num_prefill_inputs - 1].num_child_to_activate = std::min(\n            prefill_inputs[num_prefill_inputs - 1].num_child_to_activate,\n            prefill_inputs_for_all_models[i][num_prefill_inputs - 1].num_child_to_activate);\n      }\n    }\n\n    return prefill_inputs;\n  }\n\n  // Copied from NewRequestPrefillActionObj::MatchPrefixCache\n  /*!\n   * \\brief Match the request state entry with prefix cache, to skip prefilling common prefix\n   * tokens. 
If the request state entry is not added to KVCache yet, this method will add/fork the\n   * request in the KVCache, depending on the matching result from prefix cache.\n   * \\param estate The engine state.\n   * \\param[in, out] input The prefill input to be matched and updated.\n   * \\return The matched length in prefix cache.\n   */\n  int MatchPrefixCache(EngineState estate, PrefillInput* input) final {\n    RequestStateEntry rsentry = input->rsentry;\n    if (estate->prefix_cache->Mode() == PrefixCacheMode::kDisable) {\n      return 0;\n    }\n    if (rsentry->parent_idx == -1 && rsentry->status == RequestStateStatus::kPending &&\n        !estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {\n      std::vector<int32_t> tokens = GetConcatPrefillInputData(rsentry->mstates[0]);\n      if (tokens.empty()) {\n        // If the RequestStateEntry is of empty input data, or not fully tokenized, do nothing\n        // and return.\n        return 0;\n      }\n      PrefixCacheMatchedResult result = estate->prefix_cache->InsertSequence(\n          rsentry->mstates[0]->internal_id, tokens, models_[0]->GetSlidingWindowSize(),\n          models_[0]->GetAttentionSinkSize());\n\n      if (result.prefilled_offset == 0) {\n        // Add new sequence\n        TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n        for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n          Model model = models_[model_id];\n          RequestModelState mstate = rsentry->mstates[model_id];\n          model->AddNewSequence(rsentry->mstates[0]->internal_id);\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            model->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n          }\n          DisaggConfig disagg_config = 
mstate->request->generation_cfg->debug_config.disagg_config;\n          models_[model_id]->DisaggMarkKVSend(\n              mstate->internal_id, disagg_config.kv_window_begin.value_or(0),\n              disagg_config.kv_append_metadata[model_id], disagg_config.dst_group_offset.value());\n        }\n      } else {\n        if (result.forked_seq_id != -1) {\n          TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n          TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n          // Fork from active sequence\n          for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n            Model model = models_[model_id];\n            RequestModelState mstate = rsentry->mstates[model_id];\n            model->ForkSequence(result.forked_seq_id, rsentry->mstates[0]->internal_id,\n                                result.prefilled_offset);\n            // Enable sliding window for the sequence if it is not a parent.\n            if (rsentry->child_indices.empty()) {\n              model->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n            }\n            DisaggConfig disagg_config =\n                mstate->request->generation_cfg->debug_config.disagg_config;\n            models_[model_id]->DisaggMarkKVSend(\n                mstate->internal_id, disagg_config.kv_window_begin.value_or(0),\n                disagg_config.kv_append_metadata[model_id], disagg_config.dst_group_offset.value());\n          }\n        } else {\n          // Reuse recycling sequence\n          TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n          estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id);\n          for (int i = 0; i < rsentry->mstates.size(); ++i) {\n            rsentry->mstates[i]->internal_id = result.reused_seq_id;\n          }\n          if (result.reused_seq_pop_last_tokens > 0) {\n            for (Model model : models_) {\n              model->PopNFromKVCache(rsentry->mstates[0]->internal_id,\n                        
             result.reused_seq_pop_last_tokens);\n            }\n          }\n          for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n            RequestModelState mstate = rsentry->mstates[model_id];\n            DisaggConfig disagg_config =\n                mstate->request->generation_cfg->debug_config.disagg_config;\n            models_[model_id]->DisaggMarkKVSend(\n                mstate->internal_id, disagg_config.kv_window_begin.value_or(0),\n                disagg_config.kv_append_metadata[model_id], disagg_config.dst_group_offset.value());\n          }\n        }\n      }\n      // Pop matched prefix\n      if (result.prefilled_offset) {\n        for (int i = 0; i < rsentry->mstates.size(); ++i) {\n          PopPrefillInputData(rsentry->mstates[i], result.prefilled_offset);\n        }\n      }\n      // Update max prefill length\n      input->max_prefill_length =\n          std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());\n      return result.prefilled_offset;\n    }\n    return 0;\n  }\n\n  /*! \\brief Workspace of each model. */\n  std::vector<ModelWorkspace> model_workspaces_;\n  /*! \\brief The stream callback function that passes back the sampled results after prefill. */\n  FRequestStreamCallback request_stream_callback_;\n  /*! \\brief The device on which we synchronize after prefill. */\n  Device device_;\n  /*! \\brief The compute stream to run synchronization for. 
*/\n  TVMStreamHandle compute_stream_ = nullptr;\n};\n\nEngineAction EngineAction::DisaggRemoteSend(\n    Array<Model> models, std::vector<ModelWorkspace> model_workspaces, EngineConfig engine_config,\n    std::vector<tvm::ffi::json::Object> model_configs, Optional<EventTraceRecorder> trace_recorder,\n    FRequestStreamCallback request_stream_callback, Device device) {\n  return EngineAction(tvm::ffi::make_object<DisaggRemoteSendActionObj>(\n      std::move(models), std::move(model_workspaces), std::move(engine_config),\n      std::move(model_configs), std::move(trace_recorder), std::move(request_stream_callback),\n      device));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/eagle_batch_draft.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/eagle_batch_draft.cc\n */\n\n#include <numeric>\n\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs draft proposal for requests in the\n * `running_queue` of engine state. Preempt low-priority requests\n * accordingly when it is impossible to decode all the running requests.\n */\nclass EagleBatchDraftActionObj : public EngineActionObj {\n public:\n  explicit EagleBatchDraftActionObj(Array<Model> models, LogitProcessor logit_processor,\n                                    Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                    DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                    EngineConfig engine_config,\n                                    Optional<EventTraceRecorder> trace_recorder)\n      : models_(std::move(models)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        model_workspaces_(std::move(model_workspaces)),\n        draft_token_workspace_manager_(std::move(draft_token_workspace_manager)),\n        engine_config_(std::move(engine_config)),\n        trace_recorder_(std::move(trace_recorder)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests.\n    if (models_.size() != 2 || estate->running_queue.empty()) {\n      return {};\n    }\n\n    // Preempt request state entries when decode cannot apply.\n    std::vector<RequestStateEntry> running_rsentries = estate->GetRunningRequestStateEntries();\n    while (!CanDecode(running_rsentries.size())) {\n      if (estate->prefix_cache->TryFreeMemory()) continue;\n      RequestStateEntry preempted = 
PreemptLastRunningRequestStateEntry(\n          estate, models_, draft_token_workspace_manager_, trace_recorder_);\n      if (preempted.same_as(running_rsentries.back())) {\n        running_rsentries.pop_back();\n      }\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n\n    int num_rsentries = running_rsentries.size();\n    TVM_FFI_ICHECK_GT(num_rsentries, 0)\n        << \"There should be at least one request state entry that can run decode. \"\n           \"Possible failure reason: none of the prefill phase of the running requests is finished\";\n    TVM_FFI_ICHECK_LE(num_rsentries, engine_config_->max_num_sequence)\n        << \"The number of running requests exceeds the max number of sequence in EngineConfig. \"\n           \"Possible failure reason: the prefill action allows new sequence in regardless of the \"\n           \"max num sequence.\";\n\n    Array<String> request_ids;\n    std::vector<int64_t> request_internal_ids;\n    Array<GenerationConfig> generation_cfg;\n    std::vector<RandomGenerator*> rngs;\n    std::vector<std::vector<int>> draft_token_indices;\n    request_ids.reserve(num_rsentries);\n    request_internal_ids.reserve(num_rsentries);\n    generation_cfg.reserve(num_rsentries);\n    draft_token_indices.reserve(num_rsentries);\n    for (const RequestStateEntry& rsentry : running_rsentries) {\n      request_ids.push_back(rsentry->request->id);\n      request_internal_ids.push_back(rsentry->mstates[0]->internal_id);\n      generation_cfg.push_back(rsentry->request->generation_cfg);\n      rngs.push_back(&rsentry->rng);\n    }\n\n    TVM_FFI_ICHECK_GT(estate->spec_draft_length, 0)\n        << \"The speculative decoding draft length must be positive.\";\n    // The first model doesn't get involved in draft proposal.\n    for (int model_id = 1; model_id < static_cast<int>(models_.size()); ++model_id) {\n      // Collect\n      // - the last committed token,\n      // - the request model state\n      // of each request.\n    
  std::vector<int> input_tokens;\n      Array<RequestModelState> mstates;\n      input_tokens.reserve(num_rsentries);\n      mstates.reserve(num_rsentries);\n      for (const RequestStateEntry& rsentry : running_rsentries) {\n        mstates.push_back(rsentry->mstates[model_id]);\n      }\n      // draft_length_ rounds of draft proposal.\n      ObjectRef hidden_states = model_workspaces_[model_id].hidden_states;\n      // Concat last hidden_states\n      draft_token_slots_.clear();\n      if (estate->spec_draft_length > 1) {\n        for (int i = 0; i < num_rsentries; ++i) {\n          draft_token_slots_.push_back(mstates[i]->draft_token_slots.back());\n        }\n        hidden_states = models_[model_id]->GatherHiddenStates(\n            model_workspaces_[0].draft_hidden_states_storage, draft_token_slots_, &hidden_states);\n      }\n      // The first draft token has been generated in prefill/verify stage\n      for (int draft_id = 1; draft_id < estate->spec_draft_length; ++draft_id) {\n        draft_token_indices.clear();\n        auto tdraft_start = std::chrono::high_resolution_clock::now();\n        // prepare new input tokens\n        input_tokens.clear();\n        for (int i = 0; i < num_rsentries; ++i) {\n          TVM_FFI_ICHECK(!mstates[i]->draft_output_tokens.empty());\n          input_tokens.push_back(mstates[i]->draft_output_tokens.back().GetTokenId());\n          draft_token_indices.emplace_back(\n              std::vector<int>{static_cast<int>(mstates[i]->draft_output_tokens.size() - 1)});\n        }\n\n        // - Compute embeddings.\n        RECORD_EVENT(trace_recorder_, request_ids, \"start proposal embedding\");\n        ObjectRef embeddings =\n            models_[model_id]->TokenEmbed({IntTuple{input_tokens.begin(), input_tokens.end()}});\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish proposal embedding\");\n\n        // - Invoke model decode.\n        RECORD_EVENT(trace_recorder_, request_ids, \"start proposal decode\");\n        
ObjectRef fused_embedding_hidden_states = models_[model_id]->FuseEmbedHidden(\n            embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);\n        hidden_states = models_[model_id]->BatchDecodeToLastHidden(fused_embedding_hidden_states,\n                                                                   request_internal_ids);\n        Tensor logits;\n        if (models_[model_id]->CanGetLogits()) {\n          logits = models_[model_id]->GetLogits(hidden_states);\n        } else {\n          // - Use base model's head.\n          logits = models_[0]->GetLogits(hidden_states);\n        }\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish proposal decode\");\n        TVM_FFI_ICHECK_EQ(logits->ndim, 2);\n        TVM_FFI_ICHECK_EQ(logits->shape[0], num_rsentries);\n\n        // - Update logits.\n        logit_processor_->InplaceUpdateLogits(logits, generation_cfg, mstates, request_ids, nullptr,\n                                              &mstates, &draft_token_indices);\n\n        // - Compute probability distributions.\n        Tensor probs_on_device =\n            logit_processor_->ComputeProbsFromLogits(logits, generation_cfg, request_ids);\n\n        // - Commit the prefix cache changes from previous round of action.\n        // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n        estate->prefix_cache->CommitSequenceExtention();\n\n        // - Sample tokens.\n        // Fill range [0, num_rsentries) into `sample_indices`.\n        std::vector<int> sample_indices(num_rsentries);\n        std::iota(sample_indices.begin(), sample_indices.end(), 0);\n        Tensor renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(\n            probs_on_device, sample_indices, request_ids, generation_cfg);\n        std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(\n            renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);\n        
TVM_FFI_ICHECK_EQ(sample_results.size(), num_rsentries);\n\n        // - Add draft token to the state.\n        draft_token_workspace_manager_->AllocSlots(num_rsentries, &draft_token_slots_);\n        models_[model_id]->ScatterDraftProbs(probs_on_device, draft_token_slots_,\n                                             &model_workspaces_[0].draft_probs_storage);\n        // No need to save hidden states as they are not used by subsequent engine actions\n        for (int i = 0; i < num_rsentries; ++i) {\n          int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;\n          mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);\n        }\n\n        auto tdraft_end = std::chrono::high_resolution_clock::now();\n        estate->metrics.UpdateDraftTimeByBatchSize(\n            num_rsentries, static_cast<double>((tdraft_end - tdraft_start).count()) / 1e9);\n      }\n    }\n\n    auto tend = std::chrono::high_resolution_clock::now();\n    estate->metrics.engine_decode_time_sum += static_cast<double>((tend - tstart).count()) / 1e9;\n\n    return {};\n  }\n\n private:\n  /*! \\brief Check if the input requests can be decoded under conditions. */\n  bool CanDecode(int num_rsentries) {\n    // The first model is not involved in draft proposal.\n    for (int model_id = 1; model_id < static_cast<int>(models_.size()); ++model_id) {\n      // Check if the model has enough available pages.\n      int num_available_pages = models_[model_id]->GetNumAvailablePages();\n      if (num_rsentries > num_available_pages) {\n        return false;\n      }\n    }\n    return true;\n  }\n\n  /*! \\brief The model to run draft generation in speculative decoding. */\n  Array<Model> models_;\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief Workspace of each model. 
*/\n  std::vector<ModelWorkspace> model_workspaces_;\n  /*! \\brief The draft token workspace manager. */\n  DraftTokenWorkspaceManager draft_token_workspace_manager_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! \\brief Temporary buffer to store the slots of the current draft tokens */\n  std::vector<int> draft_token_slots_;\n};\n\nEngineAction EngineAction::EagleBatchDraft(Array<Model> models, LogitProcessor logit_processor,\n                                           Sampler sampler,\n                                           std::vector<ModelWorkspace> model_workspaces,\n                                           DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                           EngineConfig engine_config,\n                                           Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<EagleBatchDraftActionObj>(\n      std::move(models), std::move(logit_processor), std::move(sampler),\n      std::move(model_workspaces), std::move(draft_token_workspace_manager),\n      std::move(engine_config), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/eagle_batch_verify.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/eagle_batch_verify.cc\n */\n\n#include <tvm/runtime/threading_backend.h>\n\n#include <cmath>\n#include <exception>\n#include <numeric>\n\n#include \"../../support/random.h\"\n#include \"../config.h\"\n#include \"../model.h\"\n#include \"../sampler/sampler.h\"\n#include \"action.h\"\n#include \"action_commons.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that runs verification for requests in the\n * `running_queue` of engine state. Preempt low-priority requests\n * accordingly when it is impossible to decode all the running requests.\n */\nclass EagleBatchVerifyActionObj : public EngineActionObj {\n public:\n  explicit EagleBatchVerifyActionObj(Array<Model> models, LogitProcessor logit_processor,\n                                     Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                     DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                     EngineConfig engine_config,\n                                     Optional<EventTraceRecorder> trace_recorder)\n      : models_(std::move(models)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        model_workspaces_(std::move(model_workspaces)),\n        draft_token_workspace_manager_(std::move(draft_token_workspace_manager)),\n        engine_config_(std::move(engine_config)),\n        trace_recorder_(std::move(trace_recorder)),\n        rng_(RandomGenerator::GetInstance()) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Only run spec decode when there are two models (llm+ssm) and >=1 running requests.\n    if (models_.size() != 2 || estate->running_queue.empty()) {\n      return {};\n    }\n\n    const auto& [rsentries, draft_lengths, total_draft_length] = GetDraftsToVerify(estate);\n    TVM_FFI_ICHECK_EQ(rsentries.size(), 
draft_lengths.size());\n    if (rsentries.empty()) {\n      return {};\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n    int num_rsentries = rsentries.size();\n    Array<String> request_ids =\n        rsentries.Map([](const RequestStateEntry& rstate) { return rstate->request->id; });\n\n    // - Get embedding and run verify.\n    std::vector<int64_t> request_internal_ids;\n    std::vector<int32_t> all_tokens_to_verify;\n    Array<RequestModelState> verify_request_mstates;\n    Array<RequestModelState> draft_request_mstates;\n    Array<GenerationConfig> generation_cfg;\n    std::vector<RandomGenerator*> rngs;\n    std::vector<std::vector<SampleResult>> draft_output_tokens;\n    std::vector<std::vector<int>> draft_token_indices;\n    std::vector<int64_t> token_tree_parent_ptr;\n    request_internal_ids.reserve(num_rsentries);\n    all_tokens_to_verify.reserve(total_draft_length);\n    token_tree_parent_ptr.reserve(total_draft_length);\n    verify_request_mstates.reserve(num_rsentries);\n    draft_request_mstates.reserve(num_rsentries);\n    rngs.reserve(num_rsentries);\n    generation_cfg.reserve(num_rsentries);\n    draft_output_tokens.reserve(num_rsentries);\n    draft_token_indices.reserve(num_rsentries);\n    draft_token_slots_.clear();\n\n    for (int i = 0; i < num_rsentries; ++i) {\n      RequestModelState verify_mstate = rsentries[i]->mstates[verify_model_id_];\n      RequestModelState draft_mstate = rsentries[i]->mstates[draft_model_id_];\n      request_internal_ids.push_back(verify_mstate->internal_id);\n      TVM_FFI_ICHECK(!draft_lengths.empty());\n      TVM_FFI_ICHECK_EQ(draft_lengths[i], draft_mstate->draft_output_tokens.size());\n      TVM_FFI_ICHECK_EQ(draft_lengths[i], draft_mstate->draft_token_slots.size());\n      // the last committed token + all the draft tokens but the last one.\n      all_tokens_to_verify.push_back(draft_mstate->committed_tokens.back().GetTokenId());\n      draft_token_slots_.push_back(0);  // 
placeholder for the last committed token\n      token_tree_parent_ptr.push_back(-1);\n\n      for (int j = 0; j < static_cast<int>(draft_mstate->draft_output_tokens.size()); ++j) {\n        all_tokens_to_verify.push_back(draft_mstate->draft_output_tokens[j].GetTokenId());\n        draft_token_slots_.push_back(draft_mstate->draft_token_slots[j]);\n        token_tree_parent_ptr.push_back(draft_mstate->draft_token_parent_idx[j] + 1);\n      }\n      std::vector<int> cur_draft_token_indices(draft_mstate->draft_output_tokens.size() + 1);\n      std::iota(cur_draft_token_indices.begin(), cur_draft_token_indices.end(), -1);\n      draft_token_indices.emplace_back(std::move(cur_draft_token_indices));\n      verify_request_mstates.push_back(verify_mstate);\n      draft_request_mstates.push_back(draft_mstate);\n      generation_cfg.push_back(rsentries[i]->request->generation_cfg);\n      rngs.push_back(&rsentries[i]->rng);\n      draft_output_tokens.push_back(draft_mstate->draft_output_tokens);\n    }\n\n    Tensor draft_probs_on_device = models_[draft_model_id_]->GatherDraftProbs(\n        model_workspaces_[verify_model_id_].draft_probs_storage, draft_token_slots_,\n        &model_workspaces_[verify_model_id_].draft_probs);\n\n    std::vector<int> cum_verify_lengths = {0};\n    cum_verify_lengths.reserve(num_rsentries + 1);\n    std::vector<int> verify_lengths;\n    for (int i = 0; i < num_rsentries; ++i) {\n      // Add one committed token.\n      verify_lengths.push_back(draft_lengths[i] + 1);\n      cum_verify_lengths.push_back(cum_verify_lengths.back() + verify_lengths.back());\n    }\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"start verify embedding\");\n    ObjectRef embeddings = models_[verify_model_id_]->TokenEmbed(\n        {IntTuple{all_tokens_to_verify.begin(), all_tokens_to_verify.end()}});\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish verify embedding\");\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"start verify\");\n    ObjectRef 
hidden_states = models_[verify_model_id_]->BatchVerifyToLastHidden(\n        embeddings, request_internal_ids, verify_lengths, token_tree_parent_ptr);\n    Tensor logits = models_[verify_model_id_]->GetLogits(hidden_states);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish verify\");\n    TVM_FFI_ICHECK_EQ(logits->ndim, 2);\n    TVM_FFI_ICHECK_EQ(logits->shape[0], cum_verify_lengths.back());\n\n    // - Update logits.\n    logit_processor_->InplaceUpdateLogits(logits, generation_cfg, verify_request_mstates,\n                                          request_ids, &cum_verify_lengths, &draft_request_mstates,\n                                          &draft_token_indices);\n\n    // - Compute probability distributions.\n    Tensor probs_on_device = logit_processor_->ComputeProbsFromLogits(\n        logits, generation_cfg, request_ids, &cum_verify_lengths);\n\n    // - Commit the prefix cache changes from previous round of action.\n    // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n    estate->prefix_cache->CommitSequenceExtention();\n\n    std::vector<int> sample_indices(num_rsentries);\n    std::iota(sample_indices.begin(), sample_indices.end(), 0);\n    Tensor renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(\n        probs_on_device, sample_indices, request_ids, generation_cfg);\n    auto [sample_results_arr, _] = sampler_->BatchVerifyDraftTokensWithProbAfterTopP(\n        renormalized_probs, request_ids, cum_verify_lengths, generation_cfg, rngs,\n        draft_output_tokens, token_tree_parent_ptr, draft_probs_on_device);\n    TVM_FFI_ICHECK_EQ(sample_results_arr.size(), num_rsentries);\n\n    // We collect the requests whose drafts are fully accepted.\n    // When a request's draft is fully accepted, there is an extra token proposed\n    // by the draft model but not added into the draft model's KV cache.\n    // In this case, an additional batch decode step is needed for these requests.\n    
std::vector<int64_t> fully_accepted_rsentries;\n    std::vector<int64_t> verify_model_seq_internal_ids;\n    std::vector<int64_t> accepted_token_tree_leaf_nodes;\n    fully_accepted_rsentries.reserve(num_rsentries);\n    verify_model_seq_internal_ids.reserve(num_rsentries);\n    accepted_token_tree_leaf_nodes.reserve(num_rsentries);\n\n    std::vector<int> last_accepted_hidden_positions;\n    last_accepted_hidden_positions.reserve(num_rsentries);\n    for (int i = 0; i < num_rsentries; ++i) {\n      const std::vector<SampleResult>& sample_results = sample_results_arr[i];\n      int accept_length = sample_results.size();\n      TVM_FFI_ICHECK_GE(accept_length, 1);\n      for (SampleResult sample_result : sample_results) {\n        rsentries[i]->mstates[verify_model_id_]->CommitToken(sample_result);\n        rsentries[i]->mstates[draft_model_id_]->CommitToken(sample_result);\n      }\n      // Metrics update\n      // live update the output metrics\n      rsentries[i]->rstate->metrics.completion_tokens += accept_length;\n      rsentries[i]->rstate->metrics.decode_tokens += accept_length;\n      estate->metrics.spec_decode.Update(cum_verify_lengths[i + 1] - cum_verify_lengths[i],\n                                         accept_length);\n      // - Minus one because the last draft token has no kv cache entry\n      // - Take max with 0 in case of all accepted.\n      int rollback_length =\n          std::max(cum_verify_lengths[i + 1] - cum_verify_lengths[i] - accept_length, 0);\n\n      // Commit accepted tokens to the \"verify_model\", rollback kv cache\n      // in the \"draft_model\".\n      // NOTE: when number of small models is more than 1 (in the future),\n      // it is possible to re-compute prefill for the small models.\n      verify_model_seq_internal_ids.push_back(rsentries[i]->mstates[verify_model_id_]->internal_id);\n      accepted_token_tree_leaf_nodes.push_back(accept_length - 1);\n      if (rollback_length > 0) {\n        // Draft model rollback minus 
one because verify uses one more token.\n        models_[draft_model_id_]->PopNFromKVCache(\n            rsentries[i]->mstates[draft_model_id_]->internal_id, rollback_length - 1);\n      } else {\n        fully_accepted_rsentries.push_back(i);\n      }\n      // clear the draft model state entries\n      rsentries[i]->mstates[draft_model_id_]->RemoveAllDraftTokens(&draft_token_slots_);\n      draft_token_workspace_manager_->FreeSlots(draft_token_slots_);\n      // - Slice and save hidden_states_for_sample\n      last_accepted_hidden_positions.push_back(cum_verify_lengths[i] + accept_length - 1);\n    }\n    models_[verify_model_id_]->CommitAcceptedTokenTreeNodesToKVCache(\n        verify_model_seq_internal_ids, accepted_token_tree_leaf_nodes);\n    if (!fully_accepted_rsentries.empty() &&\n        engine_config_->speculative_mode == SpeculativeMode::kEagle) {\n      // - Run a step of batch decode for requests whose drafts are fully accepted.\n      // When a request's draft is fully accepted, there is an extra token proposed\n      // by the draft model but not added into the draft model's KV cache.\n      // In this case, an additional batch decode step is needed for these requests.\n      std::vector<int> input_tokens;\n      std::vector<int64_t> fully_accepted_request_internal_ids;\n      input_tokens.reserve(fully_accepted_rsentries.size());\n      fully_accepted_request_internal_ids.reserve(fully_accepted_rsentries.size());\n\n      std::vector<int> hidden_states_positions_for_fully_accepted;\n      hidden_states_positions_for_fully_accepted.reserve(fully_accepted_rsentries.size());\n\n      for (int rsentry_id : fully_accepted_rsentries) {\n        int num_committed_tokens =\n            rsentries[rsentry_id]->mstates[verify_model_id_]->committed_tokens.size();\n        // When a request's draft is fully accepted, an additional new token is sampled.\n        // So the token needed to fill in the draft model is the committed_token[-2].\n        
TVM_FFI_ICHECK_GE(num_committed_tokens, 2);\n        input_tokens.push_back(rsentries[rsentry_id]\n                                   ->mstates[verify_model_id_]\n                                   ->committed_tokens[num_committed_tokens - 2]\n                                   .GetTokenId());\n\n        // Taking the hidden states of the token before the last token\n        hidden_states_positions_for_fully_accepted.push_back(\n            last_accepted_hidden_positions[rsentry_id] - 1);\n        fully_accepted_request_internal_ids.push_back(\n            rsentries[rsentry_id]->mstates[draft_model_id_]->internal_id);\n      }\n\n      // - Compute embeddings.\n      ObjectRef embeddings = models_[draft_model_id_]->TokenEmbed(\n          {IntTuple{input_tokens.begin(), input_tokens.end()}});\n      // - Gather hidden states\n      ObjectRef hidden_states_for_fully_accepted = models_[draft_model_id_]->GatherHiddenStates(\n          hidden_states, hidden_states_positions_for_fully_accepted,\n          &model_workspaces_[draft_model_id_].hidden_states);\n      // - Invoke model decode.\n      ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(\n          embeddings, hidden_states_for_fully_accepted,\n          /*batch_size*/ fully_accepted_rsentries.size(), /*seq_len*/ 1);\n      hidden_states_for_fully_accepted = models_[draft_model_id_]->BatchDecodeToLastHidden(\n          fused_embedding_hidden_states, fully_accepted_request_internal_ids);\n      // - We explicitly synchronize to avoid the input tokens getting overridden in the\n      // next runs of BatchDecode.\n      // This is because we do not sample in this round of batch decode.\n      if (hidden_states_for_fully_accepted->IsInstance<DRefObj>()) {\n        Downcast<Session>(Downcast<DRef>(hidden_states_for_fully_accepted)->session)->SyncWorker(0);\n      } else {\n        Tensor hidden_states_for_fully_accepted_nd =\n            
Downcast<Tensor>(hidden_states_for_fully_accepted);\n        DeviceAPI::Get(hidden_states_for_fully_accepted_nd->device)\n            ->StreamSync(hidden_states_for_fully_accepted_nd->device, nullptr);\n      }\n    }\n    {\n      // One step draft for the following steps\n\n      // Gather hidden states for the last accepted tokens.\n      // Use the function and the workspace of the verify model because the information about the\n      // hidden states is not available in the draft model for medusa.\n      hidden_states = models_[0]->GatherHiddenStates(hidden_states, last_accepted_hidden_positions,\n                                                     &model_workspaces_[0].hidden_states);\n\n      std::vector<int> input_tokens;\n      Array<RequestModelState> mstates;\n      input_tokens.reserve(num_rsentries);\n      mstates.reserve(num_rsentries);\n      for (const RequestStateEntry& rsentry : rsentries) {\n        mstates.push_back(rsentry->mstates[draft_model_id_]);\n      }\n      for (int i = 0; i < num_rsentries; ++i) {\n        TVM_FFI_ICHECK(!mstates[i]->committed_tokens.empty());\n        input_tokens.push_back(mstates[i]->committed_tokens.back().GetTokenId());\n      }\n\n      Array<Tensor> multi_step_logits{nullptr};  // for medusa output\n      if (engine_config_->speculative_mode == SpeculativeMode::kEagle) {\n        // - Compute embeddings.\n        RECORD_EVENT(trace_recorder_, request_ids, \"start proposal embedding\");\n        embeddings = models_[draft_model_id_]->TokenEmbed(\n            {IntTuple{input_tokens.begin(), input_tokens.end()}});\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish proposal embedding\");\n\n        // - Invoke model decode.\n        RECORD_EVENT(trace_recorder_, request_ids, \"start proposal decode\");\n        ObjectRef fused_embedding_hidden_states = models_[draft_model_id_]->FuseEmbedHidden(\n            embeddings, hidden_states, /*batch_size*/ num_rsentries, /*seq_len*/ 1);\n        hidden_states = 
models_[draft_model_id_]->BatchDecodeToLastHidden(\n            fused_embedding_hidden_states, request_internal_ids);\n\n        int lm_head_model_id = models_[draft_model_id_]->CanGetLogits() ? draft_model_id_ : 0;\n        logits = models_[lm_head_model_id]->GetLogits(hidden_states);\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish proposal decode\");\n        TVM_FFI_ICHECK_EQ(logits->ndim, 2);\n        TVM_FFI_ICHECK_EQ(logits->shape[0], num_rsentries);\n      } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {\n        multi_step_logits = models_[draft_model_id_]->GetMultiStepLogits(hidden_states);\n      }\n\n      // Fill range [0, num_rsentries) into `sample_indices`.\n      std::vector<int> sample_indices(num_rsentries);\n      std::iota(sample_indices.begin(), sample_indices.end(), 0);\n\n      if (engine_config_->speculative_mode == SpeculativeMode::kEagle) {\n        const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(\n            logit_processor_, sampler_, logits, generation_cfg, request_ids, mstates, rngs,\n            sample_indices, generation_cfg, request_ids, sample_indices);\n        UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_,\n                                              renormalized_probs, hidden_states, estate);\n      } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {\n        TVM_FFI_ICHECK_NE(estate->spec_draft_length, 0);\n        for (int draft_id = 0; draft_id < estate->spec_draft_length; draft_id++) {\n          const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(\n              logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg, request_ids,\n              mstates, rngs, sample_indices, generation_cfg, request_ids, sample_indices);\n          UpdateRequestStatesWithDraftProposals(mstates, sample_results, draft_model_id_,\n                                           
     renormalized_probs, hidden_states, estate);\n        }\n      }\n    }\n    // reset num_tokens_for_next_decode\n    for (const RequestStateEntry& rsentry : rsentries) {\n      rsentry->mstates[verify_model_id_]->num_tokens_for_next_decode = 0;\n      rsentry->mstates[draft_model_id_]->num_tokens_for_next_decode = 0;\n    }\n    auto tend = std::chrono::high_resolution_clock::now();\n    double elapsed_time = static_cast<double>((tend - tstart).count()) / 1e9;\n    estate->metrics.engine_decode_time_sum += elapsed_time;\n    estate->metrics.UpdateVerifyTimeByBatchSize(cum_verify_lengths.back(), elapsed_time);\n\n    return estate->running_queue;\n  }\n\n private:\n  struct DraftRequestStateEntries {\n    /*! \\brief The request state entries to verify. */\n    Array<RequestStateEntry> draft_rsentries;\n    /*! \\brief The draft length of each request state. */\n    std::vector<int> draft_lengths;\n    /*! \\brief The total draft length. */\n    int total_draft_length;\n  };\n\n  /*!\n   * \\brief Decide whether to run verify for the draft of each request.\n   * \\param estate The engine state.\n   * \\return The drafts to verify, together with their respective\n   * state and input length.\n   */\n  DraftRequestStateEntries GetDraftsToVerify(EngineState estate) {\n    std::vector<int> draft_lengths;\n    int total_draft_length = 0;\n    int total_required_pages = 0;\n    int num_available_pages = models_[verify_model_id_]->GetNumAvailablePages();\n\n    // Preempt the request state entries that cannot fit the large model for verification.\n    std::vector<RequestStateEntry> running_rsentries = estate->GetRunningRequestStateEntries();\n    std::vector<int> num_page_requirement;\n    num_page_requirement.reserve(running_rsentries.size());\n    for (const RequestStateEntry& rsentry : running_rsentries) {\n      int draft_length = rsentry->mstates[draft_model_id_]->draft_output_tokens.size();\n      int num_require_pages = (draft_length + 
engine_config_->kv_cache_page_size - 1) /\n                              engine_config_->kv_cache_page_size;\n      draft_lengths.push_back(draft_length);\n      num_page_requirement.push_back(num_require_pages);\n      total_draft_length += draft_length;\n      total_required_pages += num_require_pages;\n    }\n    while (!CanVerify(total_required_pages)) {\n      if (estate->prefix_cache->TryFreeMemory()) continue;\n      RequestStateEntry preempted = PreemptLastRunningRequestStateEntry(\n          estate, models_, draft_token_workspace_manager_, trace_recorder_);\n      if (preempted.same_as(running_rsentries.back())) {\n        total_draft_length -= draft_lengths.back();\n        total_required_pages -= num_page_requirement.back();\n        draft_lengths.pop_back();\n        num_page_requirement.pop_back();\n        running_rsentries.pop_back();\n      }\n    }\n\n    return {running_rsentries, draft_lengths, total_draft_length};\n  }\n\n  bool CanVerify(int num_required_pages) {\n    int num_available_pages = models_[0]->GetNumAvailablePages();\n    return num_required_pages <= num_available_pages;\n  }\n\n  void UpdateRequestStatesWithDraftProposals(const Array<RequestModelState>& mstates,\n                                             const std::vector<SampleResult>& sample_results,\n                                             int model_id, const Tensor& renormalized_probs,\n                                             const ObjectRef& hidden_states_for_sample,\n                                             EngineState estate) {\n    draft_token_workspace_manager_->AllocSlots(mstates.size(), &draft_token_slots_);\n    models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_,\n                                  &model_workspaces_[0].draft_probs_storage);\n    if (engine_config_->speculative_mode == SpeculativeMode::kEagle &&\n        estate->spec_draft_length > 1) {\n      models_[0]->ScatterHiddenStates(hidden_states_for_sample, 
draft_token_slots_,\n                                      &model_workspaces_[0].draft_hidden_states_storage);\n    }\n    for (int i = 0; i < static_cast<int>(mstates.size()); ++i) {\n      int64_t parent_idx = static_cast<int64_t>(mstates[i]->draft_output_tokens.size()) - 1;\n      mstates[i]->AddDraftToken(sample_results[i], draft_token_slots_[i], parent_idx);\n    }\n  }\n  /*!\n   * \\brief The models used in speculative decoding: the verify model (id 0) and the\n   * draft model (id 1). The `Step` function of the created action takes effect only\n   * when there are exactly two models.\n   */\n  Array<Model> models_;\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief Workspace of each model. */\n  std::vector<ModelWorkspace> model_workspaces_;\n  /*! \\brief The draft token workspace manager. */\n  DraftTokenWorkspaceManager draft_token_workspace_manager_;\n  /*! \\brief The engine config. */\n  EngineConfig engine_config_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! \\brief Random number generator. */\n  RandomGenerator& rng_;\n  /*! \\brief The ids of verify/draft models. */\n  const int verify_model_id_ = 0;\n  const int draft_model_id_ = 1;\n  const float eps_ = 1e-5;\n  /*! 
\\brief Temporary buffer to store the slots of the current draft tokens */\n  std::vector<int> draft_token_slots_;\n};\n\nEngineAction EngineAction::EagleBatchVerify(\n    Array<Model> models, LogitProcessor logit_processor, Sampler sampler,\n    std::vector<ModelWorkspace> model_workspaces,\n    DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config,\n    Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<EagleBatchVerifyActionObj>(\n      std::move(models), std::move(logit_processor), std::move(sampler),\n      std::move(model_workspaces), std::move(draft_token_workspace_manager),\n      std::move(engine_config), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/eagle_new_request_prefill.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/eagle_new_request_prefill.cc\n */\n\n#include \"../sampler/sampler.h\"\n#include \"batch_prefill_base.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that prefills requests in the `waiting_queue` of\n * the engine state.\n */\nclass EagleNewRequestPrefillActionObj : public BatchPrefillBaseActionObj {\n public:\n  explicit EagleNewRequestPrefillActionObj(Array<Model> models, LogitProcessor logit_processor,\n                                           Sampler sampler,\n                                           std::vector<ModelWorkspace> model_workspaces,\n                                           DraftTokenWorkspaceManager draft_token_workspace_manager,\n                                           EngineConfig engine_config,\n                                           std::vector<tvm::ffi::json::Object> model_configs,\n                                           Optional<EventTraceRecorder> trace_recorder)\n      : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config),\n                                  std::move(model_configs), std::move(trace_recorder)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        model_workspaces_(std::move(model_workspaces)),\n        draft_token_workspace_manager_(std::move(draft_token_workspace_manager)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Find the requests in `waiting_queue` that can prefill in this step.\n    std::vector<PrefillInput> prefill_inputs;\n    {\n      NVTXScopedRange nvtx_scope(\"NewRequestPrefill getting requests\");\n      prefill_inputs = GetRequestStateEntriesToPrefill(estate);\n      if (prefill_inputs.empty()) {\n        return {};\n      }\n    }\n\n    int num_rsentries = prefill_inputs.size();\n    {\n      NVTXScopedRange nvtx_scope(\"NewRequestPrefill matching prefix\");\n      
for (int i = 0; i < num_rsentries; ++i) {\n        MatchPrefixCache(estate, &prefill_inputs[i]);\n      }\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n\n    // - Update status of request states from pending to alive.\n    Array<String> request_ids;\n    std::vector<RequestState> rstates_of_entries;\n    std::vector<RequestStateStatus> status_before_prefill;\n    UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries,\n                         &status_before_prefill);\n\n    // - Get embedding and run prefill for each model.\n    std::vector<int> prefill_lengths;\n    prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1);\n    ObjectRef hidden_states_for_input{nullptr};\n    ObjectRef hidden_states_for_sample{nullptr};\n    Tensor logits_for_sample{nullptr};\n    // A map used to record the entry and child_idx pair needed to fork sequence.\n    // The base model (id 0) should record all the pairs and all the small models\n    // fork sequences according to this map.\n    std::unordered_map<int, std::unordered_set<int>> fork_rsentry_child_map;\n    std::vector<bool> extra_prefill_tokens;\n    extra_prefill_tokens.resize(/*size=*/num_rsentries, /*value=*/false);\n    for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n      std::vector<int64_t> request_internal_ids;\n      request_internal_ids.reserve(num_rsentries);\n      ObjectRef embeddings = model_workspaces_[model_id].embeddings;\n      int cum_prefill_length = 0;\n      bool single_input =\n          num_rsentries == 1 && prefill_inputs[0].rsentry->mstates[model_id]->inputs.size() == 1;\n      for (int i = 0; i < num_rsentries; ++i) {\n        const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;\n        RequestModelState mstate = rsentry->mstates[model_id];\n        TVM_FFI_ICHECK(mstate->draft_output_tokens.empty());\n        TVM_FFI_ICHECK(mstate->draft_token_slots.empty());\n        if (status_before_prefill[i] == 
RequestStateStatus::kPending) {\n          if (!estate->prefix_cache->HasSequence(mstate->internal_id)) {\n            // Add the sequence to the model, or fork the sequence from its parent.\n            // If the sequence is already in prefix cache, it has also been added/forked in the\n            // KVCache.\n            if (rsentry->parent_idx == -1) {\n              models_[model_id]->AddNewSequence(mstate->internal_id);\n            } else {\n              models_[model_id]->ForkSequence(rstates_of_entries[i]\n                                                  ->entries[rsentry->parent_idx]\n                                                  ->mstates[model_id]\n                                                  ->internal_id,\n                                              mstate->internal_id);\n            }\n          }\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id);\n          }\n          // Shift the input tokens by 1 for eagle models.\n          if (model_id == 0) {\n            for (int j = 1; j < static_cast<int>(models_.size()); ++j) {\n              TVM_FFI_ICHECK(rsentry->mstates[j]->inputs.size());\n              TokenData token_data = Downcast<TokenData>(rsentry->mstates[j]->inputs[0]);\n              rsentry->mstates[j]->inputs.Set(\n                  0, TokenData(\n                         IntTuple(token_data->token_ids.begin() + 1, token_data->token_ids.end())));\n            }\n          }\n        }\n        request_internal_ids.push_back(mstate->internal_id);\n\n        if (engine_config_->speculative_mode == SpeculativeMode::kMedusa && model_id > 0) {\n          // Embedding is only needed for the base model in Medusa.\n          continue;\n        }\n        auto [input_data, input_length] =\n            ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length);\n        if 
(prefill_lengths[i] == -1) {\n          prefill_lengths[i] = input_length;\n        } else {\n          TVM_FFI_ICHECK_EQ(prefill_lengths[i], input_length);\n        }\n        mstate->num_prefilled_tokens += input_length;\n\n        RECORD_EVENT(trace_recorder_, prefill_inputs[i].rsentry->request->id, \"start embedding\");\n        // Speculative models shift left the input tokens by 1 when base model has committed tokens.\n        // Note: for n > 1 cases Eagle doesn't work because parent entry doesn't shift input tokens.\n        for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {\n          if (model_id == 0) {\n            mstate->prefilled_inputs.push_back(input_data[j]);\n          }\n          embeddings = input_data[j]->GetEmbedding(\n              models_[model_id],\n              /*dst=*/!single_input ? &model_workspaces_[model_id].embeddings : nullptr,\n              /*offset=*/cum_prefill_length);\n          cum_prefill_length += input_data[j]->GetLength();\n        }\n        RECORD_EVENT(trace_recorder_, rsentry->request->id, \"finish embedding\");\n      }\n\n      RECORD_EVENT(trace_recorder_, request_ids, \"start prefill\");\n\n      Array<Tensor> multi_step_logits{nullptr};\n\n      if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) {\n        ObjectRef embedding_or_hidden_states{nullptr};\n        if (model_id == 0) {\n          embedding_or_hidden_states = embeddings;\n        } else {\n          embedding_or_hidden_states =\n              models_[model_id]->FuseEmbedHidden(embeddings, hidden_states_for_input,\n                                                 /*batch_size*/ 1, /*seq_len*/ cum_prefill_length);\n        }\n        // hidden_states: (b * s, h)\n        ObjectRef hidden_states = models_[model_id]->BatchPrefillToLastHidden(\n            embedding_or_hidden_states, request_internal_ids, prefill_lengths);\n        RECORD_EVENT(trace_recorder_, request_ids, \"finish prefill\");\n\n        if 
(model_id == 0) {\n          // Keep the base model hidden states as the input of the draft models.\n          hidden_states_for_input = hidden_states;\n\n          // - Commit the prefix cache changes from previous round of action.\n          // Note: we commit prefix cache changes here to overlap this commit with the GPU\n          // execution.\n          estate->prefix_cache->CommitSequenceExtention();\n        }\n\n        // Use the base model to get logits if this model cannot compute logits itself.\n        int sample_model_id = !models_[model_id]->CanGetLogits() ? 0 : model_id;\n\n        std::vector<int> logit_positions;\n        {\n          // Prepare the logit positions: the index of each entry's last token in the batch.\n          logit_positions.reserve(prefill_lengths.size());\n          int total_len = 0;\n          for (int i = 0; i < prefill_lengths.size(); ++i) {\n            total_len += prefill_lengths[i];\n            logit_positions.push_back(total_len - 1);\n          }\n        }\n        // hidden_states_for_sample: (b, h), one row per entry\n        hidden_states_for_sample = models_[sample_model_id]->GatherHiddenStates(\n            hidden_states, logit_positions, &model_workspaces_[model_id].hidden_states);\n        // logits_for_sample: (b, v)\n        logits_for_sample = models_[sample_model_id]->GetLogits(hidden_states_for_sample);\n      } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {\n        // Note: spec_draft_length in engine config has to match the model config in Medusa.\n        multi_step_logits = models_[model_id]->GetMultiStepLogits(hidden_states_for_sample);\n      } else {\n        LOG(FATAL) << \"unreachable\";\n      }\n\n      Array<String> child_request_ids;\n      // - Prepare the configurations for the sampler.\n      //   For prefill_inputs that have children, sample\n      //   one token for each child rstate that depends on it.\n      //   Otherwise, sample a token for the current rstate.\n      std::vector<int> child_sample_indices;\n      std::vector<RequestStateEntry> rsentries_for_sample;\n      
std::vector<RandomGenerator*> rngs;\n      std::vector<bool> rsentry_activated;\n      Array<GenerationConfig> child_generation_cfg;\n      child_sample_indices.reserve(num_rsentries);\n      child_generation_cfg.reserve(num_rsentries);\n      child_request_ids.reserve(num_rsentries);\n      rsentries_for_sample.reserve(num_rsentries);\n      rngs.reserve(num_rsentries);\n      rsentry_activated.reserve(num_rsentries);\n      for (int i = 0; i < num_rsentries; ++i) {\n        const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;\n        // No sample for rsentries with remaining inputs.\n        if (!rsentry->mstates[0]->inputs.empty()) {\n          continue;\n        }\n\n        int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate;\n        for (int child_idx : rsentry->child_indices) {\n          // Only use base model to judge if we need to add child entries.\n          if ((rstates_of_entries[i]->entries[child_idx]->status == RequestStateStatus::kPending &&\n                   rstates_of_entries[i]\n                       ->entries[child_idx]\n                       ->mstates[0]\n                       ->committed_tokens.empty() ||\n               fork_rsentry_child_map[i].count(child_idx))) {\n            // If rstates_of_entries[i]->entries[child_idx] has no committed token,\n            // the prefill of the current rsentry will unblock\n            // rstates_of_entries[i]->entries[child_idx],\n            // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx].\n            fork_rsentry_child_map[i].insert(child_idx);\n            child_sample_indices.push_back(i);\n            rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]);\n            child_request_ids.push_back(rsentry->request->id);\n            child_generation_cfg.push_back(rsentry->request->generation_cfg);\n            rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng);\n\n            // We only 
fork the first `num_child_to_activate` children.\n            // The children not being forked will be forked via later prefills.\n            // Usually `num_child_to_activate` is the same as the number of children.\n            // But it can be fewer subject to the KV cache max num sequence limit.\n            if (remaining_num_child_to_activate == 0) {\n              rsentry_activated.push_back(false);\n              continue;\n            }\n            rsentry_activated.push_back(true);\n            --remaining_num_child_to_activate;\n            if (model_id == 0) {\n              TVM_FFI_ICHECK(rstates_of_entries[i]->entries[child_idx]->status ==\n                             RequestStateStatus::kPending);\n              rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive;\n            }\n            int64_t child_internal_id =\n                rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id;\n            models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id,\n                                            child_internal_id);\n            // Enable sliding window for the child sequence if the child is not a parent.\n            if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) {\n              models_[model_id]->EnableSlidingWindowForSeq(child_internal_id);\n            }\n          }\n        }\n        if (rsentry->child_indices.empty()) {\n          // If rsentry has no child, we sample a token for itself.\n          child_sample_indices.push_back(i);\n          rsentries_for_sample.push_back(rsentry);\n          child_request_ids.push_back(rsentry->request->id);\n          child_generation_cfg.push_back(rsentry->request->generation_cfg);\n          rngs.push_back(&rsentry->rng);\n          rsentry_activated.push_back(true);\n        }\n      }\n\n      // - Prepare input for logit processor.\n      TVM_FFI_ICHECK(logits_for_sample.defined());\n      
Array<GenerationConfig> generation_cfg;\n      Array<RequestModelState> mstates_for_logitproc;\n      std::vector<int> sample_indices(num_rsentries);\n      generation_cfg.reserve(num_rsentries);\n      mstates_for_logitproc.reserve(num_rsentries);\n      std::iota(sample_indices.begin(), sample_indices.end(), 0);\n      for (int i = 0; i < num_rsentries; ++i) {\n        generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg);\n        mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[model_id]);\n      }\n      if (model_id == 0 || engine_config_->speculative_mode == SpeculativeMode::kEagle) {\n        const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(\n            logit_processor_, sampler_, logits_for_sample, generation_cfg, request_ids,\n            mstates_for_logitproc, rngs, sample_indices, child_generation_cfg, child_request_ids,\n            child_sample_indices);\n        if (model_id == 0) {\n          UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated,\n                                                     sample_results);\n          // Add the sampled token as an input of the eagle models.\n          if (engine_config_->speculative_mode == SpeculativeMode::kEagle) {\n            for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {\n              for (int mid = 1; mid < static_cast<int>(models_.size()); ++mid) {\n                TokenData token_data =\n                    Downcast<TokenData>(rsentries_for_sample[i]->mstates[mid]->inputs.back());\n                std::vector<int32_t> token_ids = {token_data->token_ids.begin(),\n                                                  token_data->token_ids.end()};\n                token_ids.push_back(sample_results[i].GetTokenId());\n                int ninputs =\n                    static_cast<int>(rsentries_for_sample[i]->mstates[mid]->inputs.size());\n                
rsentries_for_sample[i]->mstates[mid]->inputs.Set(\n                    ninputs - 1, TokenData(IntTuple(token_ids.begin(), token_ids.end())));\n              }\n            }\n          }\n        } else {\n          // - Slice and save hidden_states_for_sample\n          UpdateRequestStatesWithDraftProposals(rsentries_for_sample, sample_results, model_id,\n                                                renormalized_probs, hidden_states_for_sample,\n                                                estate, child_sample_indices);\n        }\n      } else if (engine_config_->speculative_mode == SpeculativeMode::kMedusa) {\n        TVM_FFI_ICHECK_NE(estate->spec_draft_length, 0);\n        for (int draft_id = 0; draft_id < estate->spec_draft_length; ++draft_id) {\n          const auto& [renormalized_probs, sample_results] = ApplyLogitProcessorAndSample(\n              logit_processor_, sampler_, multi_step_logits[draft_id], generation_cfg, request_ids,\n              mstates_for_logitproc, rngs, sample_indices, child_generation_cfg, child_request_ids,\n              child_sample_indices);\n\n          UpdateRequestStatesWithDraftProposals(\n              rsentries_for_sample, sample_results, model_id, renormalized_probs,\n              /*hidden_states=*/ObjectRef{nullptr}, estate, child_sample_indices);\n        }\n      }\n    }\n\n    auto tend = std::chrono::high_resolution_clock::now();\n    estate->metrics.engine_prefill_time_sum += static_cast<double>((tend - tstart).count()) / 1e9;\n\n    std::vector<Request> processed_requests =\n        RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries);\n    estate->running_rsentries_changed = true;\n    return processed_requests;\n  }\n\n  void UpdateRequestStatesWithDraftProposals(\n      const std::vector<RequestStateEntry>& rsentries_for_sample,\n      const std::vector<SampleResult>& sample_results, int model_id,\n      const Tensor& renormalized_probs, const ObjectRef& hidden_states_for_sample,\n      
EngineState estate, const std::vector<int>& sample_indices) {\n    std::vector<int> reuse_count(renormalized_probs->shape[0], 0);\n    for (int i = 0; i < static_cast<int>(sample_indices.size()); ++i) {\n      // The same probability may be sampled multiple times.\n      reuse_count[sample_indices[i]]++;\n    }\n    draft_token_workspace_manager_->AllocSlots(renormalized_probs->shape[0], reuse_count,\n                                               &draft_token_slots_);\n\n    models_[0]->ScatterDraftProbs(renormalized_probs, draft_token_slots_,\n                                  &model_workspaces_[0].draft_probs_storage);\n    if (engine_config_->speculative_mode == SpeculativeMode::kEagle &&\n        estate->spec_draft_length > 1) {\n      models_[0]->ScatterHiddenStates(hidden_states_for_sample, draft_token_slots_,\n                                      &model_workspaces_[0].draft_hidden_states_storage);\n    }\n    for (int i = 0; i < static_cast<int>(rsentries_for_sample.size()); ++i) {\n      int parent_idx =\n          rsentries_for_sample[i]->mstates[model_id]->draft_output_tokens.empty()\n              ? -1\n              : rsentries_for_sample[i]->mstates[model_id]->draft_output_tokens.size() - 1;\n      rsentries_for_sample[i]->mstates[model_id]->AddDraftToken(\n          sample_results[i], draft_token_slots_[sample_indices[i]], parent_idx);\n    }\n  }\n\n private:\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief Workspace of each model. */\n  std::vector<ModelWorkspace> model_workspaces_;\n  /*! \\brief The draft token workspace manager. */\n  DraftTokenWorkspaceManager draft_token_workspace_manager_;\n  /*! \\brief Temporary buffer to store the slots of the current draft tokens */\n  std::vector<int> draft_token_slots_;\n\n  /*!\n   * \\brief Match the request state entry with prefix cache, to skip prefilling common prefix\n   * tokens. 
If the request state entry is not added to KVCache yet, this method will add/fork the\n   * request in the KVCache, depending on the matching result from prefix cache.\n   * \\param estate The engine state.\n   * \\param[in, out] input The prefill input to be matched and updated.\n   */\n  int MatchPrefixCache(EngineState estate, PrefillInput* input) final {\n    RequestStateEntry rsentry = input->rsentry;\n    if (estate->prefix_cache->Mode() == PrefixCacheMode::kDisable) {\n      return 0;\n    }\n    if (rsentry->parent_idx == -1 && rsentry->status == RequestStateStatus::kPending &&\n        !estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {\n      std::vector<int32_t> tokens = GetConcatPrefillInputData(rsentry->mstates[0]);\n      if (tokens.empty()) {\n        // If the RequestStateEntry has empty input data or is not fully tokenized, do nothing\n        // and return.\n        return 0;\n      }\n      PrefixCacheMatchedResult result = estate->prefix_cache->InsertSequence(\n          rsentry->mstates[0]->internal_id, tokens, models_[0]->GetSlidingWindowSize(),\n          models_[0]->GetAttentionSinkSize());\n      if (result.prefilled_offset == 0) {\n        // Add new sequence.\n        // Note: This is almost the same as without Eagle speculative decoding, except that in the\n        // prefill step the draft model's prefill embedding input is shifted by one token relative\n        // to the base model. It is just a new sequence without a prefix cache match. 
Here we merely add the new sequence\n        // in advance of the prefill step.\n        TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n        for (int i = 0; i < models_.size(); ++i) {\n          models_[i]->AddNewSequence(rsentry->mstates[0]->internal_id);\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            models_[i]->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n          }\n        }\n      } else {\n        if (result.forked_seq_id != -1) {\n          // Fork from an active sequence.\n          // Note: Because the draft model KVCache is shifted by one token relative to the base\n          // model KVCache, forking needs special handling.\n          // For example, we have a sequence [0, 1, 2] in the base model KVCache and the\n          // corresponding sequence [1, 2, 3] in the draft model KVCache, where token 3 was\n          // sampled from the base model but not appended to the base model KVCache. Then we get\n          // a new sequence [0, 1, 4] to prefill. Although the new sequence matches the first two\n          // tokens of [0, 1, 2], we have to fork from the first token 0, not the second token 1.\n          // If we forked from the second token, the prefill would be:\n          //   Base model:  [0, 1] + prefill([4]) => [5]\n          //   Draft model: [1] + prefill([4, 5])\n          // The prefill lengths then differ between the base model and the draft model, which is\n          // illegal. So we roll back one token in the prefix cache to fork from the first token.\n          // 
Then the prefill will be:\n          //   Base model:  [0] + prefill([1, 4]) => [5]\n          //   Draft model: [1] + prefill([4, 5])\n          // We also shift the input prefill data as for any other new sequence, to avoid\n          // prefilling token 1 twice and to keep the prefill lengths of the base model and the\n          // draft model aligned.\n          TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n          TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n          estate->prefix_cache->RollBackSequence(rsentry->mstates[0]->internal_id, 1);\n          for (int i = 0; i < models_.size(); ++i) {\n            models_[i]->ForkSequence(result.forked_seq_id, rsentry->mstates[0]->internal_id,\n                                     result.prefilled_offset - 1);\n            // Enable sliding window for the sequence if it is not a parent.\n            if (rsentry->child_indices.empty()) {\n              models_[i]->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n            }\n          }\n        } else {\n          // Reuse a recycled sequence.\n          // Note: Reusing a recycled sequence is handled like forking a sequence. 
And we\n          // also roll back one token due to the reason mentioned above.\n          TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n          estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id);\n          for (int i = 0; i < rsentry->mstates.size(); ++i) {\n            rsentry->mstates[i]->internal_id = result.reused_seq_id;\n          }\n          estate->prefix_cache->RollBackSequence(rsentry->mstates[0]->internal_id, 1);\n          for (int i = 0; i < models_.size(); ++i) {\n            models_[i]->PopNFromKVCache(rsentry->mstates[0]->internal_id,\n                                        result.reused_seq_pop_last_tokens + 1);\n          }\n          result.prefilled_offset -= 1;\n        }\n      }\n      // Pop matched prefix\n      if (result.prefilled_offset > 0) {\n        for (int i = 0; i < rsentry->mstates.size(); ++i) {\n          PopPrefillInputData(rsentry->mstates[i], result.prefilled_offset);\n        }\n      }\n      // Update max prefill length\n      input->max_prefill_length =\n          std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());\n      return result.prefilled_offset - 1;\n    }\n    return 0;\n  }\n};\n\nEngineAction EngineAction::EagleNewRequestPrefill(\n    Array<Model> models, LogitProcessor logit_processor, Sampler sampler,\n    std::vector<ModelWorkspace> model_workspaces,\n    DraftTokenWorkspaceManager draft_token_workspace_manager, EngineConfig engine_config,\n    std::vector<tvm::ffi::json::Object> model_configs,\n    Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<EagleNewRequestPrefillActionObj>(\n      std::move(models), std::move(logit_processor), std::move(sampler),\n      std::move(model_workspaces), std::move(draft_token_workspace_manager),\n      std::move(engine_config), std::move(model_configs), std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_actions/new_request_prefill.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_actions/new_request_prefill.cc\n */\n\n#include \"../sampler/sampler.h\"\n#include \"batch_prefill_base.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*!\n * \\brief The action that prefills requests in the `waiting_queue` of\n * the engine state.\n */\nclass NewRequestPrefillActionObj : public BatchPrefillBaseActionObj {\n public:\n  explicit NewRequestPrefillActionObj(Array<Model> models, LogitProcessor logit_processor,\n                                      Sampler sampler, std::vector<ModelWorkspace> model_workspaces,\n                                      EngineConfig engine_config,\n                                      std::vector<tvm::ffi::json::Object> model_configs,\n                                      Optional<EventTraceRecorder> trace_recorder)\n      : BatchPrefillBaseActionObj(std::move(models), std::move(engine_config),\n                                  std::move(model_configs), std::move(trace_recorder)),\n        logit_processor_(std::move(logit_processor)),\n        sampler_(std::move(sampler)),\n        model_workspaces_(std::move(model_workspaces)) {}\n\n  Array<Request> Step(EngineState estate) final {\n    // - Find the requests in `waiting_queue` that can prefill in this step.\n    std::vector<PrefillInput> prefill_inputs;\n    {\n      NVTXScopedRange nvtx_scope(\"NewRequestPrefill getting requests\");\n      prefill_inputs = GetRequestStateEntriesToPrefill(estate);\n      if (prefill_inputs.empty()) {\n        return {};\n      }\n    }\n\n    int num_rsentries = prefill_inputs.size();\n    {\n      NVTXScopedRange nvtx_scope(\"NewRequestPrefill matching prefix\");\n      for (int i = 0; i < num_rsentries; ++i) {\n        MatchPrefixCache(estate, &prefill_inputs[i]);\n      }\n    }\n\n    auto tstart = std::chrono::high_resolution_clock::now();\n\n    // - Update status of request states from pending to alive.\n    Array<String> 
request_ids;\n    std::vector<RequestState> rstates_of_entries;\n    std::vector<RequestStateStatus> status_before_prefill;\n    UpdateRequestToAlive(prefill_inputs, estate, &request_ids, &rstates_of_entries,\n                         &status_before_prefill);\n\n    // - Get embedding and run prefill for each model.\n    std::vector<int> prefill_lengths;\n    prefill_lengths.resize(/*size=*/num_rsentries, /*value=*/-1);\n    Tensor logits_for_sample{nullptr};\n    for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n      std::vector<int64_t> request_internal_ids;\n      request_internal_ids.reserve(num_rsentries);\n      ObjectRef embeddings = model_workspaces_[model_id].embeddings;\n      int cum_prefill_length = 0;\n      bool single_input =\n          num_rsentries == 1 && prefill_inputs[0].rsentry->mstates[model_id]->inputs.size() == 1;\n      std::vector<int64_t> cached_token_data;\n      for (int i = 0; i < num_rsentries; ++i) {\n        const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;\n        RequestModelState mstate = rsentry->mstates[model_id];\n        auto [input_data, input_length] =\n            ChunkPrefillInputData(mstate, prefill_inputs[i].max_prefill_length);\n        if (prefill_lengths[i] == -1) {\n          prefill_lengths[i] = input_length;\n        } else {\n          TVM_FFI_ICHECK_EQ(prefill_lengths[i], input_length);\n        }\n        mstate->num_prefilled_tokens += input_length;\n\n        TVM_FFI_ICHECK(mstate->draft_output_tokens.empty());\n        TVM_FFI_ICHECK(mstate->draft_token_slots.empty());\n        if (status_before_prefill[i] == RequestStateStatus::kPending &&\n            !estate->prefix_cache->HasSequence(mstate->internal_id)) {\n          // Add the sequence to the model, or fork the sequence from its parent.\n          // If the sequence is already in prefix cache, it has also been added/forked in the\n          // KVCache.\n          if (rsentry->parent_idx == -1) {\n          
  models_[model_id]->AddNewSequence(mstate->internal_id);\n          } else {\n            models_[model_id]->ForkSequence(\n                rstates_of_entries[i]->entries[rsentry->parent_idx]->mstates[model_id]->internal_id,\n                mstate->internal_id);\n          }\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            models_[model_id]->EnableSlidingWindowForSeq(mstate->internal_id);\n          }\n        }\n        request_internal_ids.push_back(mstate->internal_id);\n        RECORD_EVENT(trace_recorder_, rsentry->request->id, \"start embedding\");\n        for (int j = 0; j < static_cast<int>(input_data.size()); ++j) {\n          if (!model_id && !prefill_inputs[i].is_decode) {\n            mstate->prefilled_inputs.push_back(input_data[j]);\n          }\n          if (const auto* token_data = input_data[j].as<TokenDataNode>()) {\n            cached_token_data.insert(cached_token_data.end(), token_data->token_ids.begin(),\n                                     token_data->token_ids.end());\n          } else {\n            if (!cached_token_data.empty()) {\n              embeddings = TokenData(cached_token_data)\n                               ->GetEmbedding(models_[model_id],\n                                              /*dst=*/!single_input ? &embeddings : nullptr,\n                                              /*offset=*/cum_prefill_length);\n              cum_prefill_length += cached_token_data.size();\n              cached_token_data.clear();\n            }\n            embeddings = input_data[j]->GetEmbedding(models_[model_id],\n                                                     /*dst=*/!single_input ? 
&embeddings : nullptr,\n                                                     /*offset=*/cum_prefill_length);\n            cum_prefill_length += input_data[j]->GetLength();\n          }\n        }\n        RECORD_EVENT(trace_recorder_, rsentry->request->id, \"finish embedding\");\n      }\n      if (!cached_token_data.empty()) {\n        embeddings = TokenData(cached_token_data)\n                         ->GetEmbedding(models_[model_id],\n                                        /*dst=*/!single_input ? &embeddings : nullptr,\n                                        /*offset=*/cum_prefill_length);\n        cum_prefill_length += cached_token_data.size();\n        cached_token_data.clear();\n      }\n\n      RECORD_EVENT(trace_recorder_, request_ids, \"start prefill\");\n      Tensor logits =\n          models_[model_id]->BatchPrefill(embeddings, request_internal_ids, prefill_lengths);\n      RECORD_EVENT(trace_recorder_, request_ids, \"finish prefill\");\n      TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n      TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n      TVM_FFI_ICHECK_EQ(logits->shape[1], num_rsentries);\n\n      if (model_id == 0) {\n        // We only need to sample for model 0 in prefill.\n        logits_for_sample = logits;\n      }\n    }\n\n    // - Update logits.\n    TVM_FFI_ICHECK(logits_for_sample.defined());\n    Array<GenerationConfig> generation_cfg;\n    Array<RequestModelState> mstates_for_logitproc;\n    generation_cfg.reserve(num_rsentries);\n    mstates_for_logitproc.reserve(num_rsentries);\n    for (int i = 0; i < num_rsentries; ++i) {\n      generation_cfg.push_back(prefill_inputs[i].rsentry->request->generation_cfg);\n      mstates_for_logitproc.push_back(prefill_inputs[i].rsentry->mstates[0]);\n    }\n    logits_for_sample = logits_for_sample.CreateView({num_rsentries, logits_for_sample->shape[2]},\n                                                     logits_for_sample->dtype);\n    logit_processor_->InplaceUpdateLogits(logits_for_sample, 
generation_cfg, mstates_for_logitproc,\n                                          request_ids);\n\n    // - Compute probability distributions.\n    Tensor probs_on_device =\n        logit_processor_->ComputeProbsFromLogits(logits_for_sample, generation_cfg, request_ids);\n\n    // - Commit the prefix cache changes from previous round of action.\n    // Note: we commit prefix cache changes here to overlap this commit with the GPU execution.\n    estate->prefix_cache->CommitSequenceExtention();\n\n    // - Sample tokens.\n    //   For rsentries that have children, sample\n    //   one token for each child rstate that depends on it.\n    //   Otherwise, sample a token for the current rstate.\n    std::vector<int> sample_indices;\n    std::vector<RequestStateEntry> rsentries_for_sample;\n    std::vector<RandomGenerator*> rngs;\n    std::vector<bool> rsentry_activated;\n    sample_indices.reserve(num_rsentries);\n    rsentries_for_sample.reserve(num_rsentries);\n    rngs.reserve(num_rsentries);\n    rsentry_activated.reserve(num_rsentries);\n    request_ids.clear();\n    generation_cfg.clear();\n    for (int i = 0; i < num_rsentries; ++i) {\n      const RequestStateEntry& rsentry = prefill_inputs[i].rsentry;\n      // No sample for rsentries with remaining inputs.\n      if (!rsentry->mstates[0]->inputs.empty()) {\n        continue;\n      }\n\n      int remaining_num_child_to_activate = prefill_inputs[i].num_child_to_activate;\n      for (int child_idx : rsentry->child_indices) {\n        // If rstates_of_entries[i]->entries[child_idx] has no committed token,\n        // the prefill of the current rsentry will unblock\n        // rstates_of_entries[i]->entries[child_idx],\n        // and thus we want to sample a token for rstates_of_entries[i]->entries[child_idx].\n        if (rstates_of_entries[i]->entries[child_idx]->status != RequestStateStatus::kPending ||\n            !rstates_of_entries[i]->entries[child_idx]->mstates[0]->committed_tokens.empty()) {\n          
continue;\n        }\n        sample_indices.push_back(i);\n        rsentries_for_sample.push_back(rstates_of_entries[i]->entries[child_idx]);\n        request_ids.push_back(rsentry->request->id);\n        generation_cfg.push_back(rsentry->request->generation_cfg);\n        rngs.push_back(&rstates_of_entries[i]->entries[child_idx]->rng);\n\n        TVM_FFI_ICHECK(rstates_of_entries[i]->entries[child_idx]->status ==\n                       RequestStateStatus::kPending);\n        // We only fork the first `num_child_to_activate` children.\n        // The children not being forked will be forked via later prefills.\n        // Usually `num_child_to_activate` is the same as the number of children.\n        // But it can be fewer subject to the KV cache max num sequence limit.\n        if (remaining_num_child_to_activate == 0) {\n          rsentry_activated.push_back(false);\n          continue;\n        }\n        rsentry_activated.push_back(true);\n        --remaining_num_child_to_activate;\n        rstates_of_entries[i]->entries[child_idx]->status = RequestStateStatus::kAlive;\n        for (int model_id = 0; model_id < static_cast<int>(models_.size()); ++model_id) {\n          int64_t child_internal_id =\n              rstates_of_entries[i]->entries[child_idx]->mstates[model_id]->internal_id;\n          models_[model_id]->ForkSequence(rsentry->mstates[model_id]->internal_id,\n                                          child_internal_id);\n          // Enable sliding window for the child sequence if the child is not a parent.\n          if (rstates_of_entries[i]->entries[child_idx]->child_indices.empty()) {\n            models_[model_id]->EnableSlidingWindowForSeq(child_internal_id);\n          }\n        }\n      }\n      if (rsentry->child_indices.empty()) {\n        // If rsentry has no child, we sample a token for itself.\n        sample_indices.push_back(i);\n        rsentries_for_sample.push_back(rsentry);\n        request_ids.push_back(rsentry->request->id);\n   
     generation_cfg.push_back(rsentry->request->generation_cfg);\n        rngs.push_back(&rsentry->rng);\n        rsentry_activated.push_back(true);\n      }\n    }\n    Tensor renormalized_probs = sampler_->BatchRenormalizeProbsByTopP(\n        probs_on_device, sample_indices, request_ids, generation_cfg);\n    std::vector<SampleResult> sample_results = sampler_->BatchSampleTokensWithProbAfterTopP(\n        renormalized_probs, sample_indices, request_ids, generation_cfg, rngs);\n    TVM_FFI_ICHECK_EQ(sample_results.size(), rsentries_for_sample.size());\n\n    // - Update the committed tokens of states.\n    // - If a request is first-time prefilled, set the prefill finish time.\n    UpdateRequestStateEntriesWithSampleResults(rsentries_for_sample, rsentry_activated,\n                                               sample_results);\n\n    auto tend = std::chrono::high_resolution_clock::now();\n    estate->metrics.engine_prefill_time_sum += static_cast<double>((tend - tstart).count()) / 1e9;\n\n    std::vector<Request> processed_requests =\n        RemoveProcessedRequests(prefill_inputs, estate, rstates_of_entries);\n    estate->running_rsentries_changed = true;\n    return processed_requests;\n  }\n\n private:\n  /*! \\brief The logit processor. */\n  LogitProcessor logit_processor_;\n  /*! \\brief The sampler to sample new tokens. */\n  Sampler sampler_;\n  /*! \\brief Workspace of each model. */\n  std::vector<ModelWorkspace> model_workspaces_;\n\n  /*!\n   * \\brief Match the request state entry with prefix cache, to skip prefilling common prefix\n   * tokens. 
If the request state entry has not been added to the KV cache yet, this method\n   * adds or forks the request in the KV cache, depending on the prefix cache matching result.\n   * \\param estate The engine state.\n   * \\param[in, out] input The prefill input to be matched and updated.\n   * \\return The length matched in the prefix cache.\n   */\n  int MatchPrefixCache(EngineState estate, PrefillInput* input) final {\n    RequestStateEntry rsentry = input->rsentry;\n    if (estate->prefix_cache->Mode() == PrefixCacheMode::kDisable) {\n      return 0;\n    }\n    if (rsentry->parent_idx == -1 && rsentry->status == RequestStateStatus::kPending &&\n        !estate->prefix_cache->HasSequence(rsentry->mstates[0]->internal_id)) {\n      std::vector<int32_t> tokens = GetConcatPrefillInputData(rsentry->mstates[0]);\n      if (tokens.empty()) {\n        // If the RequestStateEntry has empty input data or is not fully tokenized, do nothing\n        // and return.\n        return 0;\n      }\n      PrefixCacheMatchedResult result = estate->prefix_cache->InsertSequence(\n          rsentry->mstates[0]->internal_id, tokens, models_[0]->GetSlidingWindowSize(),\n          models_[0]->GetAttentionSinkSize());\n\n      if (result.prefilled_offset == 0) {\n        // Add a new sequence\n        TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n        TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n        for (Model model : models_) {\n          model->AddNewSequence(rsentry->mstates[0]->internal_id);\n          // Enable sliding window for the sequence if it is not a parent.\n          if (rsentry->child_indices.empty()) {\n            model->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n          }\n        }\n      } else {\n        if (result.forked_seq_id != -1) {\n          TVM_FFI_ICHECK_EQ(result.reused_seq_id, -1);\n          TVM_FFI_ICHECK_EQ(result.reused_seq_pop_last_tokens, 0);\n          // Fork from 
active sequence\n          for (Model model : models_) {\n            model->ForkSequence(result.forked_seq_id, rsentry->mstates[0]->internal_id,\n                                result.prefilled_offset);\n            // Enable sliding window for the sequence if it is not a parent.\n            if (rsentry->child_indices.empty()) {\n              model->EnableSlidingWindowForSeq(rsentry->mstates[0]->internal_id);\n            }\n          }\n        } else {\n          // Reuse a recycled sequence\n          TVM_FFI_ICHECK_EQ(result.forked_seq_id, -1);\n          estate->id_manager.RecycleId(rsentry->mstates[0]->internal_id);\n          for (int i = 0; i < rsentry->mstates.size(); ++i) {\n            rsentry->mstates[i]->internal_id = result.reused_seq_id;\n          }\n          if (result.reused_seq_pop_last_tokens > 0) {\n            for (Model model : models_) {\n              model->PopNFromKVCache(rsentry->mstates[0]->internal_id,\n                                     result.reused_seq_pop_last_tokens);\n            }\n          }\n        }\n      }\n      // Pop matched prefix\n      if (result.prefilled_offset) {\n        for (int i = 0; i < rsentry->mstates.size(); ++i) {\n          PopPrefillInputData(rsentry->mstates[i], result.prefilled_offset);\n        }\n      }\n      // Update max prefill length\n      input->max_prefill_length =\n          std::min(input->max_prefill_length, rsentry->mstates[0]->GetInputLength());\n      return result.prefilled_offset;\n    }\n    return 0;\n  }\n};\n\nEngineAction EngineAction::NewRequestPrefill(Array<Model> models, LogitProcessor logit_processor,\n                                             Sampler sampler,\n                                             std::vector<ModelWorkspace> model_workspaces,\n                                             EngineConfig engine_config,\n                                             std::vector<tvm::ffi::json::Object> model_configs,\n                                   
          Optional<EventTraceRecorder> trace_recorder) {\n  return EngineAction(tvm::ffi::make_object<NewRequestPrefillActionObj>(\n      std::move(models), std::move(logit_processor), std::move(sampler),\n      std::move(model_workspaces), std::move(engine_config), std::move(model_configs),\n      std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_state.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_state.cc\n */\n#include \"engine_state.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() { EngineStateObj::RegisterReflection(); }\n\nEngineState::EngineState() { data_ = tvm::ffi::make_object<EngineStateObj>(); }\n\nvoid EngineStateObj::Reset() {\n  running_queue.clear();\n  waiting_queue.clear();\n  request_states.clear();\n  id_manager.Reset();\n  metrics.Reset();\n  if (prefix_cache.defined()) {\n    prefix_cache->Reset();\n  }\n  running_rsentries_changed = true;\n  postproc_workspace = ActionPostProcessWorkspace();\n}\n\nRequestState EngineStateObj::GetRequestState(Request request) {\n  TVM_FFI_ICHECK(request->rstate != nullptr) << \"The state of the request has not been defined.\";\n  return GetRef<RequestState>(static_cast<RequestStateNode*>(request->rstate));\n}\n\nconst std::vector<RequestStateEntry>& EngineStateObj::GetRunningRequestStateEntries() {\n  if (running_rsentries_changed) {\n    cached_running_rsentries_.clear();\n    for (const Request& request : running_queue) {\n      for (const RequestStateEntry& rsentry : GetRequestState(request)->entries) {\n        // A request state entry is considered running for decode if it is alive, is a leaf,\n        // and has finished prefilling all of its inputs.\n        if (rsentry->status == RequestStateStatus::kAlive && rsentry->child_indices.empty() &&\n            rsentry->mstates[0]->inputs.empty()) {\n          cached_running_rsentries_.push_back(rsentry);\n        }\n      }\n    }\n    running_rsentries_changed = false;\n  }\n  return cached_running_rsentries_;\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/engine_state.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/engine_state.h\n */\n#ifndef MLC_LLM_SERVE_ENGINE_STATE_H_\n#define MLC_LLM_SERVE_ENGINE_STATE_H_\n\n#include <tvm/ffi/string.h>\n\n#include \"config.h\"\n#include \"metrics.h\"\n#include \"prefix_cache.h\"\n#include \"request.h\"\n#include \"request_state.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\ntypedef TypedFunction<void(Array<RequestStreamOutput>)> FRequestStreamCallback;\n\n/*! \\brief The manager of internal ids for requests in the engine. */\nstruct EngineInternalIDManager {\n  std::vector<int64_t> available_ids;\n  int64_t id_cnt = 0;\n\n  /*! \\brief Return an unused id. */\n  int64_t GetNewId() {\n    if (!available_ids.empty()) {\n      int64_t id = available_ids.back();\n      available_ids.pop_back();\n      return id;\n    } else {\n      return id_cnt++;\n    }\n  }\n\n  /*! \\brief Recycle an id. */\n  void RecycleId(int64_t id) { available_ids.push_back(id); }\n\n  /*! \\brief Reset the manager. */\n  void Reset() {\n    available_ids.clear();\n    id_cnt = 0;\n  }\n};\n\n/*! \\brief The data structures used in the action post-process. */\nstruct ActionPostProcessWorkspace {\n  std::vector<RequestStateEntry> finished_rsentries;\n  Array<RequestStreamOutput> callback_delta_outputs;\n};\n\n/*!\n * \\brief The state of the running engine.\n * It contains the requests and their states submitted to the Engine.\n */\nclass EngineStateObj : public Object {\n public:\n  /*! \\brief The requests being processed. */\n  std::vector<Request> running_queue;\n  /*! \\brief The requests that have not started processing yet. */\n  std::vector<Request> waiting_queue;\n  /*! \\brief The states of all requests. */\n  std::unordered_map<String, RequestState> request_states;\n  /*! \\brief The internal id manager. */\n  EngineInternalIDManager id_manager;\n  /*! \\brief Runtime metrics. */\n  EngineMetrics metrics;\n  /*! 
\\brief The prefix cache. */\n  PrefixCache prefix_cache{nullptr};\n  /*! \\brief A boolean flag denoting whether the running request state entry list has changed. */\n  bool running_rsentries_changed = true;\n  /*!\n   * \\brief The current engine speculative decoding draft length.\n   * The length may change across time under the auto speculative decoding mode.\n   * Value 0 means undefined. It must have a positive value for speculative decoding to\n   * properly work.\n   */\n  int spec_draft_length = 0;\n  /*! \\brief A boolean flag denoting whether the engine is in disaggregation mode. */\n  bool disaggregation = false;\n  // Request stream callback function\n  FRequestStreamCallback request_stream_callback_;\n  /*!\n   * \\brief The post-process data structures.\n   * We make it a workspace to avoid repetitive memory allocation/free in the action post process.\n   */\n  ActionPostProcessWorkspace postproc_workspace;\n\n  /*! \\brief Reset the engine state and clear the metrics. */\n  void Reset();\n  /*! \\brief Get the request state of the given request. */\n  RequestState GetRequestState(Request request);\n  /*! 
\brief Return the running request state entries. */\n  const std::vector<RequestStateEntry>& GetRunningRequestStateEntries();\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<EngineStateObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.EngineState\", EngineStateObj, Object);\n\n private:\n  std::vector<RequestStateEntry> cached_running_rsentries_;\n};\n\n/*!\n * \\brief Managed reference to EngineStateObj.\n * \\sa EngineStateObj\n */\nclass EngineState : public ObjectRef {\n public:\n  explicit EngineState();\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(EngineState, ObjectRef, EngineStateObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_ENGINE_STATE_H_\n"
  },
  {
    "path": "cpp/serve/event_trace_recorder.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/event_trace_recorder.cc\n */\n#include \"event_trace_recorder.h\"\n\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n\n#include <algorithm>\n#include <chrono>\n#include <mutex>\n#include <unordered_map>\n#include <utility>\n#include <vector>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::ffi::String;\n\nnamespace detail {\n\nstruct PairHash {\n  template <class T1, class T2>\n  std::size_t operator()(const std::pair<T1, T2>& p) const {\n    auto h1 = std::hash<T1>{}(p.first);\n    auto h2 = std::hash<T2>{}(p.second);\n    return h1 ^ h2;\n  }\n};\n\n}  // namespace detail\n\n/*! \\brief The implementation of event trace recorder. */\nclass EventTraceRecorderImpl : public EventTraceRecorderObj {\n public:\n  void AddEvent(const String& request_id, const std::string& event) final {\n    double event_time = std::chrono::duration_cast<std::chrono::duration<double>>(\n                            std::chrono::system_clock::now().time_since_epoch())\n                            .count();\n\n    {\n      std::lock_guard<std::mutex> lock(mutex_);\n      AddEventInternal(request_id, event, event_time);\n    }\n  }\n\n  void AddEvent(const Array<String>& request_ids, const std::string& event) final {\n    double event_time = std::chrono::duration_cast<std::chrono::duration<double>>(\n                            std::chrono::system_clock::now().time_since_epoch())\n                            .count();  // in seconds\n\n    {\n      std::lock_guard<std::mutex> lock(mutex_);\n      for (const String& request_id : request_ids) {\n        AddEventInternal(request_id, event, event_time);\n      }\n    }\n  }\n\n  std::string DumpJSON() final {\n    std::unordered_map<std::string, std::vector<std::pair<std::string, double>>> local_events;\n    {\n      std::lock_guard<std::mutex> 
lock(mutex_);\n      local_events = events_;\n    }\n\n    auto fcmp_events = [](const std::pair<int64_t, tvm::ffi::json::Value>& lhs,\n                          const std::pair<int64_t, tvm::ffi::json::Value>& rhs) {\n      return lhs.first < rhs.first;\n    };\n\n    tvm::ffi::json::Array event_array;\n    for (const std::string& request_id : request_id_in_order_) {\n      std::vector<std::pair<std::string, double>> event_pairs = local_events.at(request_id);\n      std::vector<std::pair<int64_t, tvm::ffi::json::Value>> events_to_sort;\n      events_to_sort.reserve(event_pairs.size());\n      for (int i = 0; i < static_cast<int>(event_pairs.size()); ++i) {\n        std::string event = event_pairs[i].first;\n        double event_time = event_pairs[i].second;\n        std::string name;\n        std::string phase;\n        if (event.compare(0, 6, \"start \") == 0) {\n          // Duration begin.\n          name = event.substr(6);\n          phase = \"B\";\n        } else if (event.compare(0, 7, \"finish \") == 0) {\n          // Duration end.\n          name = event.substr(7);\n          phase = \"E\";\n        } else {\n          // Instant event.\n          name = event;\n          phase = \"i\";\n        }\n        int64_t event_time_in_us = static_cast<int64_t>(event_time * 1e6);\n\n        tvm::ffi::json::Object event_json;\n        event_json.Set(\"name\", name);\n        event_json.Set(\"ph\", phase);\n        event_json.Set(\"ts\", event_time_in_us);\n        event_json.Set(\"pid\", static_cast<int64_t>(1));\n        event_json.Set(\"tid\", request_id);\n\n        events_to_sort.push_back({event_time_in_us, event_json});\n      }\n      std::sort(events_to_sort.begin(), events_to_sort.end(), fcmp_events);\n      for (auto [timestamp, event] : events_to_sort) {\n        event_array.push_back(std::move(event));\n      }\n    }\n    return tvm::ffi::json::Stringify(event_array);\n  }\n\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.EventTraceRecorder\", 
EventTraceRecorderImpl,\n                              EventTraceRecorderObj);\n\n private:\n  /*! \\brief The internal impl of AddEvent, taking the event time as input. */\n  void AddEventInternal(const std::string& request_id, const std::string& event,\n                        double event_time) {\n    if (std::find(request_id_in_order_.begin(), request_id_in_order_.end(), request_id) ==\n        request_id_in_order_.end()) {\n      request_id_in_order_.push_back(request_id);\n    }\n    int event_cnt = event_counter_[{request_id, event}]++;\n    events_[request_id].push_back({event + \" (\" + std::to_string(event_cnt) + \")\", event_time});\n  }\n\n  /*! \\brief The mutex ensuring only one thread can access critical regions. */\n  std::mutex mutex_;\n\n  /************** Critical Regions **************/\n  /*! \\brief The request ids in time order. Each id only appears once. */\n  std::vector<std::string> request_id_in_order_;\n  /*! \\brief The number of a certain event for a request. */\n  std::unordered_map<std::pair<std::string, std::string>, int, detail::PairHash> event_counter_;\n  /*! \\brief The event list of each request together with the timestamps. 
*/\n  std::unordered_map<std::string, std::vector<std::pair<std::string, double>>> events_;\n};\n\nEventTraceRecorder EventTraceRecorder::Create() {\n  return EventTraceRecorder(tvm::ffi::make_object<EventTraceRecorderImpl>());\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  EventTraceRecorderImpl::RegisterReflection();\n  refl::GlobalDef()\n      .def(\"mlc.serve.EventTraceRecorder\", []() { return EventTraceRecorder::Create(); })\n      .def(\"mlc.serve.EventTraceRecorderAddEvent\",\n           [](const EventTraceRecorder& trace_recorder, const String& request_id,\n              const std::string& event) { trace_recorder->AddEvent(request_id, event); })\n      .def_method(\"mlc.serve.EventTraceRecorderDumpJSON\", &EventTraceRecorderObj::DumpJSON);\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/event_trace_recorder.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/event_trace_recorder.h\n * \\brief The event trace recorder for requests in MLC LLM.\n */\n#ifndef MLC_LLM_SERVE_EVENT_TRACE_RECORDER_H_\n#define MLC_LLM_SERVE_EVENT_TRACE_RECORDER_H_\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/object.h>\n\n#include <string>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\nusing tvm::ffi::Array;\nusing tvm::ffi::String;\n\n/*! \\brief The event trace recorder for requests. */\nclass EventTraceRecorderObj : public Object {\n public:\n  /*!\n   * \\brief Record an event for the input request in the trace recorder.\n   * \\param request_id The id of the request the event belongs to.\n   * \\param event The event name as a string.\n   * It can have one of the following patterns:\n   * - \"start xxx\", which marks the start of event \"xxx\",\n   * - \"finish xxx\", which marks the finish of event \"xxx\",\n   * - \"yyy\", which marks the instant event \"yyy\".\n   * Matching \"start\" and \"finish\" events are automatically paired in the trace recorder.\n   */\n  virtual void AddEvent(const String& request_id, const std::string& event) = 0;\n\n  /*! \\brief Record an event for each request in the input list. */\n  virtual void AddEvent(const Array<String>& request_ids, const std::string& event) = 0;\n\n  /*! \\brief Dump the logged events in Chrome Trace Event Format as a JSON string. 
*/\n  virtual std::string DumpJSON() = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<EventTraceRecorderObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.EventTraceRecorder\", EventTraceRecorderObj, Object);\n};\n\n/*!\n * \\brief Managed reference to EventTraceRecorderObj.\n * \\sa EventTraceRecorderObj\n */\nclass EventTraceRecorder : public ObjectRef {\n public:\n  /*! \\brief Create an event trace recorder. */\n  static EventTraceRecorder Create();\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(EventTraceRecorder, ObjectRef, EventTraceRecorderObj);\n};\n\n/****************** Helper macro ******************/\n\n/*! \\brief Record an event for the input request or list of requests. */\n#define RECORD_EVENT(trace_recorder, request_ids, event)  \\\n  if (trace_recorder.defined()) {                         \\\n    trace_recorder.value()->AddEvent(request_ids, event); \\\n  }\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_EVENT_TRACE_RECORDER_H_\n"
  },
  {
    "path": "cpp/serve/function_table.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/function_table.cc\n * \\brief The implementation of the function table used in serving for distributed inference.\n */\n\n#include \"function_table.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/runtime/disco/session.h>\n#include <tvm/runtime/memory/memory_manager.h>\n#include <tvm/runtime/module.h>\n#include <tvm/runtime/tensor.h>\n\n#include <cstdlib>\n#include <filesystem>\n#include <string>\n#include <vector>\n\n#include \"../support/load_bytes_from_file.h\"\n#include \"../support/utils.h\"\n#include \"sampler/sampler.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nOptional<IntTuple> GetDiscoWorkerCPUBinding(int num_workers) {\n  const char* raw_cpu_binding = std::getenv(\"MLC_DISCO_WORKER_CPU_BINDING\");\n  if (raw_cpu_binding == nullptr) {\n    return std::nullopt;\n  }\n\n  std::string cpu_binding_str(raw_cpu_binding);\n  std::vector<std::string> cpu_ids_str = Split(cpu_binding_str, ',');\n  std::vector<int64_t> cpu_ids;\n  for (const std::string& cpu_id_str : cpu_ids_str) {\n    try {\n      cpu_ids.push_back(std::stol(cpu_id_str));\n    } catch (std::invalid_argument const& ex) {\n      LOG(FATAL) << \"Invalid MLC_DISCO_WORKER_CPU_BINDING \\\"\" << cpu_binding_str << \"\\\"\";\n    }\n  }\n  if (static_cast<int>(cpu_ids.size()) < num_workers) {\n    LOG(FATAL) << \"Insufficient number of specified CPU workers in MLC_DISCO_WORKER_CPU_BINDING, \"\n                  \"expecting at least \"\n               << num_workers << \" CPU ids but only \" << cpu_ids.size() << \" are given.\";\n  }\n\n  return IntTuple{cpu_ids};\n}\n\nFunction FunctionTable::SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name) {\n  return Function([sess, func = std::move(sess_func), name = std::move(name)](\n                      ffi::PackedArgs args, ffi::Any* rv) -> void {\n    std::vector<AnyView> packed_args(args.size() + 3);\n    packed_args[0] = 
static_cast<int>(DiscoAction::kCallPacked);\n    packed_args[1] = 0;\n    packed_args[2] = func;\n    for (int i = 0; i < args.size(); ++i) {\n      packed_args[i + 3] = args[i];\n    }\n    *rv = sess->CallWithPacked(tvm::ffi::PackedArgs(packed_args.data(), packed_args.size()));\n  });\n}\n\nvoid FunctionTable::Init(String reload_lib_path, Device device, tvm::ffi::json::Object model_config,\n                         Optional<Session> session, int num_shards, int num_stages) {\n  local_gpu_device = device;\n  this->model_config = model_config;\n  this->cached_buffers = Map<String, ObjectRef>();\n\n  int num_workers = num_shards * num_stages;\n  if (num_workers > 1) {\n    TVM_FFI_ICHECK(session.defined());\n    this->sess = session.value();\n    this->use_disco = true;\n    this->disco_mod = sess->CallPacked(sess->GetGlobalFunc(\"runtime.disco.load_vm_module\"),\n                                       reload_lib_path, Optional<Device>(std::nullopt));\n    this->mod_get_func = [this, fmodule_get_function = sess->GetGlobalFunc(\n                                    \"ffi.ModuleGetFunction\")](const std::string& name) -> Function {\n      DRef func = sess->CallPacked(fmodule_get_function, this->disco_mod, name, true);\n      bool exists = (func->DebugGetFromRemote(0).as<Function>()) != nullptr;\n      if (!exists) {\n        return Function(nullptr);\n      }\n      return SessionFuncAsPackedFunc(sess, func, name);\n    };\n    if (num_stages == 1) {\n      if (Optional<IntTuple> cpu_ids = GetDiscoWorkerCPUBinding(/*num_workers=*/num_shards)) {\n        IntTuple cpu_ids_value = cpu_ids.value();\n        sess->CallPacked(sess->GetGlobalFunc(\"runtime.disco.bind_worker_to_cpu_core\"),\n                         cpu_ids_value);\n      }\n    }\n    this->get_global_func = [this](const std::string& name) -> Function {\n      return SessionFuncAsPackedFunc(sess, sess->GetGlobalFunc(name), name);\n    };\n    this->model_metadata_ = ModelMetadata::FromModule(\n        
this->disco_mod.value()->DebugGetFromRemote(0).cast<Module>(), std::move(model_config));\n    this->_InitFunctions();\n  } else {\n    TVM_FFI_ICHECK(!session.defined());\n    Optional<Module> executable = std::nullopt;\n    Optional<Function> fload_exec;\n    if (StartsWith(reload_lib_path, \"system://\")) {\n      static Function f_load_system_lib = Function::GetGlobalRequired(\"ffi.SystemLib\");\n      std::string system_lib_prefix = std::string(reload_lib_path).substr(9);\n      std::replace(system_lib_prefix.begin(), system_lib_prefix.end(), /*old=*/'-', /*new=*/'_');\n      executable = f_load_system_lib(system_lib_prefix + \"_\").cast<Module>();\n      fload_exec = executable.value()->GetFunction(\"vm_load_executable\");\n      TVM_FFI_ICHECK(fload_exec.defined())\n          << \"Cannot find system lib with \" << system_lib_prefix\n          << \", please make sure you set model_lib field consistently with the compilation \";\n    } else {\n      executable = tvm::ffi::Module::LoadFromFile(reload_lib_path);\n      fload_exec = executable.value()->GetFunction(\"vm_load_executable\");\n      /* precompile opencl kernel programs */\n      if (device.device_type == kDLOpenCL) {\n        auto f_get = executable.value()->GetFunction(\"opencl.GetPreCompiledPrograms\", true);\n        TVM_FFI_ICHECK(f_get.defined()) << \"Cannot find opencl.GetPreCompiledPrograms\";\n        tvm::ffi::String bytes = f_get.value()().cast<String>();\n        auto f_set = executable.value()->GetFunction(\"opencl.SetPreCompiledPrograms\", true);\n        TVM_FFI_ICHECK(f_set.defined()) << \"Cannot find opencl.SetPreCompiledPrograms\";\n        f_set.value()(tvm::ffi::String(bytes));\n      }\n      TVM_FFI_ICHECK(fload_exec.defined()) << \"TVM runtime cannot find vm_load_executable\";\n    }\n    this->use_disco = false;\n    this->local_vm = fload_exec.value()().cast<Module>();\n    this->local_vm.value()\n        ->GetFunction(\"vm_initialization\")\n        
.value()(static_cast<int>(device.device_type), device.device_id,\n                 static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled),\n                 static_cast<int>(kDLCPU), 0,\n                 static_cast<int>(tvm::runtime::memory::AllocatorType::kPooled));\n    this->mod_get_func = [this](const std::string& name) -> Function {\n      return this->local_vm.value()->GetFunction(name, true).value_or(Function(nullptr));\n    };\n    this->get_global_func = [](const std::string& name) -> Function {\n      return Function::GetGlobalRequired(name);\n    };\n    this->model_metadata_ =\n        ModelMetadata::FromModule(this->local_vm.value(), std::move(model_config));\n    this->_InitFunctions();\n  }\n  TVM_FFI_ICHECK_EQ(this->model_metadata_.tensor_parallel_shards, num_shards);\n  TVM_FFI_ICHECK_EQ(this->model_metadata_.pipeline_parallel_stages, num_stages);\n  // Invoke the CUDA graph allocation init function if it is defined.\n  if (cuda_graph_alloc_init_func_.defined()) {\n    this->cuda_graph_alloc_init_func_();\n  }\n}\n\nObjectRef FunctionTable::LoadParams(const std::string& model_path, Device device) {\n  if (this->use_disco) {\n    Optional<DRef> params = std::nullopt;\n    if (this->model_metadata_.params.empty()) {\n      std::filesystem::path fs_model_path = model_path;\n      std::string metadata_path = (fs_model_path / \"tensor-cache.json\").string();\n      std::string tensor_cache_metadata = LoadBytesFromFile(metadata_path);\n      Function loader_create = this->get_global_func(\"runtime.disco.ShardLoader\");\n\n      auto load_all_func_name = \"runtime.disco.ShardLoaderLoadAll\";\n      Function loader_load_all = this->get_global_func(load_all_func_name);\n      TVM_FFI_ICHECK(loader_create != nullptr);\n      TVM_FFI_ICHECK(loader_load_all != nullptr);\n      DRef loader =\n          loader_create(metadata_path, tensor_cache_metadata, \"\", this->disco_mod).cast<DRef>();\n      params = loader_load_all(loader).cast<DRef>();\n    } 
else {\n      auto load_func_name = getenv(\"MLC_INTERNAL_PRESHARD_NUM\") == nullptr\n                                ? \"mlc.multi_gpu.LoadMultiGPU\"\n                                : \"mlc.multi_gpu.LoadMultiGPUPresharded\";\n      Function loader = this->get_global_func(load_func_name);\n      params = loader(model_path, this->disco_mod, tvm::ffi::json::Stringify(this->model_config))\n                   .cast<DRef>();\n    }\n    return params.value();\n  } else {\n    static Function fload_cache = Function::GetGlobalRequired(\"vm.builtin.tensor_cache.load\");\n    fload_cache(model_path, static_cast<int32_t>(device.device_type), device.device_id);\n    Array<Tensor> params;\n    if (this->model_metadata_.params.empty()) {\n      constexpr const char* name_loader = \"vm.builtin.param_array_from_cache\";\n      static Function fload_params = Function::GetGlobalRequired(name_loader);\n      params = fload_params(\"param\", -1).cast<Array<Tensor>>();\n    } else {\n      constexpr const char* name_loader = \"vm.builtin.param_array_from_cache_by_name\";\n      static Function fload_params = Function::GetGlobalRequired(name_loader);\n      Array<String> param_names;\n      param_names.reserve(this->model_metadata_.params.size());\n      for (const auto& param : this->model_metadata_.params) {\n        param_names.push_back(param.name);\n      }\n      params = fload_params(param_names).cast<Array<Tensor>>();\n    }\n    // after we get params, it is safe to simply clear the cached version\n    // as these params are referenced by params_\n    static Function fclear_tensor_cache =\n        Function::GetGlobalRequired(\"vm.builtin.tensor_cache.clear\");\n    fclear_tensor_cache();\n    return params;\n  }\n}\n\nvoid FunctionTable::_InitFunctions() {\n  this->embed_func_ = mod_get_func(\"embed\");\n  this->image_embed_func_ = mod_get_func(\"image_embed\");\n  this->single_batch_prefill_func_ = mod_get_func(\"prefill\");\n  this->single_batch_decode_func_ = 
mod_get_func(\"decode\");\n  this->single_batch_extend_func_ = mod_get_func(\"extend\");\n  this->prefill_func_ = mod_get_func(\"batch_prefill\");\n  this->decode_func_ = mod_get_func(\"batch_decode\");\n  this->extend_func_ = mod_get_func(\"batch_extend\");\n  this->verify_func_ = mod_get_func(\"batch_verify\");\n  this->single_batch_prefill_to_last_hidden_func_ = mod_get_func(\"prefill_to_last_hidden_states\");\n  this->single_batch_decode_to_last_hidden_func_ = mod_get_func(\"decode_to_last_hidden_states\");\n  this->prefill_to_last_hidden_func_ = mod_get_func(\"batch_prefill_to_last_hidden_states\");\n  this->decode_to_last_hidden_func_ = mod_get_func(\"batch_decode_to_last_hidden_states\");\n  this->verify_to_last_hidden_func_ = mod_get_func(\"batch_verify_to_last_hidden_states\");\n  this->fuse_embed_hidden_func_ = mod_get_func(\"fuse_embed_hidden_states\");\n  Module mod = this->use_disco ? this->disco_mod.value()->DebugGetFromRemote(0).cast<Module>()\n                               : this->local_vm.value();\n  this->get_logits_func_ = mod_get_func(\"get_logits\");\n  this->batch_get_logits_func_ = mod_get_func(\"batch_get_logits\");\n  this->batch_select_last_hidden_func_ = mod_get_func(\"batch_select_last_hidden_states\");\n  this->softmax_func_ =\n      mod->GetFunction(\"softmax_with_temperature\", true).value_or(Function(nullptr));\n  this->apply_logit_bias_func_ =\n      mod->GetFunction(\"apply_logit_bias_inplace\", true).value_or(Function(nullptr));\n  this->apply_penalty_func_ =\n      mod->GetFunction(\"apply_penalty_inplace\", true).value_or(Function(nullptr));\n  this->apply_bitmask_func_ =\n      mod->GetFunction(\"apply_bitmask_inplace\", true).value_or(Function(nullptr));\n  this->alloc_embedding_tensor_func_ = mod_get_func(\"alloc_embedding_tensor\");\n  this->cuda_graph_alloc_init_func_ = mod_get_func(\"cuda_graph_alloc_init\");\n  this->create_kv_cache_func_ = mod_get_func(\"create_flashinfer_paged_kv_cache\");\n  if 
(this->model_metadata_.sliding_window_size != -1 || !this->create_kv_cache_func_.defined()) {\n    Function f_create_rnn_state = mod_get_func(\"create_rnn_state\");\n    if (f_create_rnn_state.defined()) {\n      this->create_kv_cache_func_ = f_create_rnn_state;\n    } else {\n      this->create_kv_cache_func_ = mod_get_func(\"create_tir_paged_kv_cache\");\n    }\n  }\n  this->reset_kv_cache_func_ = get_global_func(\"vm.builtin.kv_state_clear\");\n  this->kv_cache_add_sequence_func_ = get_global_func(\"vm.builtin.kv_state_add_sequence\");\n  this->kv_cache_fork_sequence_func_ = get_global_func(\"vm.builtin.kv_state_fork_sequence\");\n  this->kv_cache_enable_sliding_window_for_seq_ =\n      get_global_func(\"vm.builtin.attention_kv_cache_enable_sliding_window_for_seq\");\n  this->kv_cache_remove_sequence_func_ = get_global_func(\"vm.builtin.kv_state_remove_sequence\");\n  this->kv_cache_begin_forward_func_ = get_global_func(\"vm.builtin.kv_state_begin_forward\");\n  this->kv_cache_end_forward_func_ = get_global_func(\"vm.builtin.kv_state_end_forward\");\n  this->kv_cache_disagg_prepare_recv_func_ =\n      get_global_func(\"vm.builtin.kv_cache_disagg_prepare_recv\");\n  this->kv_cache_disagg_mark_send_func_ = get_global_func(\"vm.builtin.kv_cache_disagg_mark_send\");\n  this->kv_cache_popn_func_ = get_global_func(\"vm.builtin.kv_state_popn\");\n  this->kv_cache_commit_accepted_token_tree_nodes_func_ =\n      get_global_func(\"vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes\");\n  this->kv_cache_get_num_available_pages_func_ =\n      Function::GetGlobalRequired(\"vm.builtin.attention_kv_cache_get_num_available_pages\");\n  this->kv_cache_get_total_sequence_length_func_ =\n      Function::GetGlobalRequired(\"vm.builtin.attention_kv_cache_get_total_sequence_length\");\n  if (Sampler::SupportGPUSampler(local_gpu_device)) {\n    gpu_multinomial_from_uniform_func_ =\n        mod->GetFunction(\"multinomial_from_uniform\", true).value_or(Function(nullptr));\n  
  gpu_argsort_probs_func_ = mod->GetFunction(\"argsort_probs\", true).value_or(Function(nullptr));\n    gpu_sample_with_top_p_func_ =\n        mod->GetFunction(\"sample_with_top_p\", true).value_or(Function(nullptr));\n    gpu_sampler_take_probs_func_ =\n        mod->GetFunction(\"sampler_take_probs\", true).value_or(Function(nullptr));\n    gpu_verify_draft_tokens_func_ =\n        mod->GetFunction(\"sampler_verify_draft_tokens\", true).value_or(Function(nullptr));\n    gpu_renormalize_by_top_p_func_ =\n        mod->GetFunction(\"renormalize_by_top_p\", true).value_or(Function(nullptr));\n  }\n  this->nd_view_func_ = get_global_func(\"vm.builtin.reshape\");\n  this->nd_get_shape_func_ = get_global_func(\"vm.builtin.shape_of\");\n  this->nd_copy_embedding_to_offset_func_ = get_global_func(\"mlc.copy_embedding_to_offset\");\n  support_backtracking_kv_ = true;\n  this->tuple_getitem_func_ = get_global_func(\"vm.builtin.tuple_getitem\");\n  if (use_disco) {\n    this->last_group_send_to_worker_0_ =\n        get_global_func(\"mlc.multi_gpu.SendFromLastGroupToWorker0\");\n  }\n\n  this->gather_probs_func_ = mod->GetFunction(\"gather_probs\", true).value_or(Function(nullptr));\n  this->scatter_probs_func_ = mod->GetFunction(\"scatter_probs\", true).value_or(Function(nullptr));\n  this->gather_hidden_states_func_ = mod_get_func(\"gather_hidden_states\");\n  this->scatter_hidden_states_func_ = mod_get_func(\"scatter_hidden_states\");\n}\n\nObjectRef FunctionTable::Empty(Shape shape, DataType dtype, Device device,\n                               bool worker0_only) const {\n  if (this->use_disco) {\n    DRef empty_func = sess->GetGlobalFunc(\"runtime.disco.empty\");\n    return sess->CallPacked(empty_func, shape, dtype, Optional<Device>(std::nullopt), worker0_only,\n                            /*in_group=*/false);\n  } else {\n    return Tensor::Empty(shape, dtype, device);\n  }\n}\n\nObjectRef FunctionTable::CopyToWorker0(const Tensor& host_array, String buffer_cache_key,\n  
                                     Shape max_reserved_shape, bool local_only) {\n  Map<String, ObjectRef> cached_buffers = this->cached_buffers.value();\n  if (this->use_disco && !local_only) {\n    Device null_device{DLDeviceType(0), 0};\n    Optional<DRef> buffer = std::nullopt;\n    auto it = cached_buffers.find(buffer_cache_key);\n    if (it != cached_buffers.end()) {\n      buffer = Downcast<DRef>((*it).second);\n    } else {\n      buffer = Downcast<DRef>(this->Empty(max_reserved_shape, host_array.DataType(), null_device,\n                                          /*worker0_only=*/false));\n      cached_buffers.Set(buffer_cache_key, buffer.value());\n    }\n    Shape real_shape = host_array.Shape();\n    DRef buffer_view = nd_view_func_(buffer.value(), real_shape).cast<DRef>();\n    sess->CopyToWorker0(host_array, buffer_view);\n    return buffer_view;\n  } else {\n    auto it = cached_buffers.find(buffer_cache_key);\n    Tensor buffer{nullptr};\n    if (it != cached_buffers.end()) {\n      buffer = Downcast<Tensor>((*it).second);\n      if (buffer_cache_key == \"image\") {\n        if (runtime::GetDataSize(*buffer.operator->()) <\n            runtime::GetDataSize(*host_array.operator->())) {\n          buffer = Tensor::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);\n          cached_buffers.Set(buffer_cache_key, buffer);\n        }\n      }\n    } else {\n      buffer = Tensor::Empty(max_reserved_shape, host_array->dtype, local_gpu_device);\n      cached_buffers.Set(buffer_cache_key, buffer);\n    }\n    buffer = buffer.CreateView(host_array.Shape(), host_array->dtype);\n    DLTensor copy_dst = *(buffer.operator->());\n    Tensor::CopyFromTo(host_array.operator->(), &copy_dst);\n    return buffer;\n  }\n}\n\nvoid FunctionTable::DebugCallFuncOnAllAllWorker(const String& func_name,\n                                                Optional<String> func_args) const {\n  if (func_args) {\n    std::string args = func_args.value();\n    if 
(this->use_disco) {\n      sess->CallPacked(sess->GetGlobalFunc(func_name), args);\n    } else {\n      static Function func = Function::GetGlobalRequired(func_name);\n      func(args);\n    }\n  } else {\n    if (this->use_disco) {\n      sess->CallPacked(sess->GetGlobalFunc(func_name));\n    } else {\n      static Function func = Function::GetGlobalRequired(func_name);\n      func();\n    }\n  }\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/function_table.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/function_table.h\n * \\brief The header for function table in serving for distributed inference.\n */\n\n#ifndef MLC_LLM_SERVE_FUNCTION_TABLE_H_\n#define MLC_LLM_SERVE_FUNCTION_TABLE_H_\n\n#include <tvm/ffi/container/map.h>\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/extra/module.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/optional.h>\n#include <tvm/runtime/disco/session.h>\n#include <tvm/runtime/module.h>\n#include <tvm/runtime/tensor.h>\n\n#include <string>\n\n#include \"../metadata/model.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\nusing tvm::ffi::Function;\nusing tvm::ffi::Map;\nusing tvm::ffi::Optional;\nusing tvm::ffi::TypedFunction;\n\n//--------------------------------------------------------\n// The function table under batching settings.\n// The implementation is mostly the same as the one for\n// single-sequence distributed inference in llm_chat.cc.\n// The only difference is that the function table for\n// batching uses a different set of packed functions.\n//\n// Here we choose to have the duplicate code instead of\n// reusing the existing function table. 
This is mainly\n// for the independent development of batching/serving\n// and to keep the codebase manageable.\n// We will eventually merge the two implementations into one\n// after the batching development becomes stable.\n//--------------------------------------------------------\nstruct FunctionTable {\n  static Function SessionFuncAsPackedFunc(Session sess, DRef sess_func, String name);\n\n  void Init(String reload_lib_path, Device device, tvm::ffi::json::Object model_config,\n            Optional<Session> session, int num_shards, int num_stages);\n\n  ObjectRef LoadParams(const std::string& model_path, Device device);\n\n  void _InitFunctions();\n\n  ObjectRef Empty(Shape shape, DataType dtype, Device device, bool worker0_only) const;\n\n  /*!\n   * \\brief Copy a host array to the worker or local gpu.\n   * \\param host_array The host array to be copied.\n   * \\param buffer_cache_key The key to the buffer cache.\n   * \\param max_reserved_shape The maximum shape to be reserved in the buffer cache.\n   * \\param local_only Whether to copy the array to the local gpu only. If true, the use_disco\n   *                  flag will be ignored. 
This can be useful for functions that run only on the\n   *                  local gpu when disco is enabled.\n   * \\return The array on the worker or local gpu.\n   */\n  ObjectRef CopyToWorker0(const Tensor& host_array, String buffer_cache_key,\n                          Shape max_reserved_shape, bool local_only = false);\n\n  void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) const;\n\n  bool use_disco = false;\n  Device local_gpu_device;\n  Session sess{nullptr};\n  Optional<DRef> disco_mod = std::nullopt;\n  Optional<Map<String, ObjectRef>> cached_buffers = std::nullopt;\n  Optional<tvm::ffi::Module> local_vm = std::nullopt;\n  tvm::ffi::json::Object model_config;\n\n  TypedFunction<Function(const std::string&)> mod_get_func;\n  TypedFunction<Function(const std::string&)> get_global_func;\n\n  ModelMetadata model_metadata_;\n\n  Function embed_func_;\n  Function image_embed_func_;\n  Function single_batch_prefill_func_;\n  Function single_batch_decode_func_;\n  Function single_batch_extend_func_;\n  Function prefill_func_;\n  Function decode_func_;\n  Function extend_func_;\n  Function verify_func_;\n  Function single_batch_prefill_to_last_hidden_func_;\n  Function single_batch_decode_to_last_hidden_func_;\n  Function prefill_to_last_hidden_func_;\n  Function decode_to_last_hidden_func_;\n  Function verify_to_last_hidden_func_;\n  Function fuse_embed_hidden_func_;\n  Function get_logits_func_;\n  Function batch_get_logits_func_;\n  Function batch_select_last_hidden_func_;\n  Function softmax_func_;\n  Function apply_logit_bias_func_;\n  Function apply_penalty_func_;\n  Function apply_bitmask_func_;\n  Function alloc_embedding_tensor_func_;\n  Function cuda_graph_alloc_init_func_;\n  Function create_kv_cache_func_;\n  Function reset_kv_cache_func_;\n  bool support_backtracking_kv_;\n  Function kv_cache_add_sequence_func_;\n  Function kv_cache_fork_sequence_func_;\n  Function kv_cache_enable_sliding_window_for_seq_;\n  
Function kv_cache_remove_sequence_func_;\n  Function kv_cache_begin_forward_func_;\n  Function kv_cache_end_forward_func_;\n  Function kv_cache_disagg_prepare_recv_func_;\n  Function kv_cache_disagg_mark_send_func_;\n  Function kv_cache_popn_func_;\n  Function kv_cache_commit_accepted_token_tree_nodes_func_;\n  Function kv_cache_get_num_available_pages_func_;\n  Function kv_cache_get_total_sequence_length_func_;\n  Function gpu_multinomial_from_uniform_func_;\n  Function gpu_argsort_probs_func_;\n  Function gpu_sample_with_top_p_func_;\n  Function gpu_sampler_take_probs_func_;\n  Function gpu_verify_draft_tokens_func_;\n  Function gpu_renormalize_by_top_p_func_;\n  Function nd_view_func_;\n  Function nd_get_shape_func_;\n  Function nd_copy_embedding_to_offset_func_;\n  Function tuple_getitem_func_;\n  Function last_group_send_to_worker_0_;\n  // Auxiliary functions for speculative decoding.\n  Function gather_probs_func_;\n  Function scatter_probs_func_;\n  Function gather_hidden_states_func_;\n  Function scatter_hidden_states_func_;\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_FUNCTION_TABLE_H_\n"
  },
  {
    "path": "cpp/serve/logit_processor.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/logit_processor.cc\n * \\brief The implementation of logit processor.\n */\n#include \"logit_processor.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/runtime/device_api.h>\n#include <tvm/runtime/nvtx.h>\n#include <tvm/runtime/threading_backend.h>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\ninline void CopyArray(Tensor src, Tensor dst, TVMStreamHandle copy_stream) {\n  DLTensor dl_dst = *(dst.operator->());\n  Tensor::CopyFromTo(src.operator->(), &dl_dst, copy_stream);\n}\n\ninline void SyncCopyStream(Device device, TVMStreamHandle compute_stream,\n                           TVMStreamHandle copy_stream) {\n  // - If there is no particular copy stream, no action is needed.\n  if (copy_stream == nullptr) {\n    return;\n  }\n  // - Sync two streams.\n  DeviceAPI::Get(device)->SyncStreamFromTo(device, copy_stream, compute_stream);\n}\n\n/***************** LogitProcessor Implementation *****************/\n\nTVM_FFI_STATIC_INIT_BLOCK() { LogitProcessorObj::RegisterReflection(); }\n\nclass LogitProcessorImpl : public LogitProcessorObj {\n public:\n  /*! * \\brief Constructor of LogitProcessorImpl. 
*/\n  explicit LogitProcessorImpl(int max_num_token, int vocab_size, FunctionTable* ft, DLDevice device,\n                              Optional<EventTraceRecorder> trace_recorder)\n      : max_num_token_(max_num_token),\n        vocab_size_(vocab_size),\n        bitmask_size_((vocab_size + 31) / 32),\n        softmax_func_(ft->softmax_func_),\n        device_(device),\n        apply_logit_bias_func_(ft->apply_logit_bias_func_),\n        apply_penalty_func_(ft->apply_penalty_func_),\n        apply_bitmask_func_(ft->apply_bitmask_func_),\n        trace_recorder_(std::move(trace_recorder)) {\n    Device preferred_host_device = GetPreferredHostDevice(device);\n    // Initialize auxiliary arrays on CPU.\n    seq_ids_host_ = Tensor::Empty({max_num_token}, dtype_i32_, preferred_host_device);\n    pos2seq_id_host_ =\n        Tensor::Empty({max_num_token * vocab_size}, dtype_i32_, preferred_host_device);\n    token_ids_host_ =\n        Tensor::Empty({max_num_token * vocab_size}, dtype_i32_, preferred_host_device);\n    token_cnt_host_ =\n        Tensor::Empty({max_num_token * vocab_size}, dtype_i32_, preferred_host_device);\n    token_logit_bias_host_ =\n        Tensor::Empty({max_num_token * vocab_size}, dtype_f32_, preferred_host_device);\n    penalties_host_ = Tensor::Empty({max_num_token, 3}, dtype_f32_, preferred_host_device);\n    bitmask_host_ =\n        Tensor::Empty({max_num_token, bitmask_size_}, dtype_i32_, preferred_host_device);\n    temperature_host_ = Tensor::Empty({max_num_token}, dtype_f32_, preferred_host_device);\n    // Initialize auxiliary arrays on GPU.\n    seq_ids_device_ = Tensor::Empty({max_num_token}, dtype_i32_, device);\n    pos2seq_id_device_ = Tensor::Empty({max_num_token * vocab_size}, dtype_i32_, device);\n    token_ids_device_ = Tensor::Empty({max_num_token * vocab_size}, dtype_i32_, device);\n    token_cnt_device_ = Tensor::Empty({max_num_token * vocab_size}, dtype_i32_, device);\n    token_logit_bias_device_ = 
Tensor::Empty({max_num_token * vocab_size}, dtype_f32_, device);\n    penalties_device_ = Tensor::Empty({max_num_token, 3}, dtype_f32_, device);\n    bitmask_device_ = Tensor::Empty({max_num_token, bitmask_size_}, dtype_i32_, device);\n    temperature_device_ = Tensor::Empty({max_num_token}, dtype_f32_, device);\n\n    TVM_FFI_ICHECK(apply_logit_bias_func_.defined())\n        << \"Function \\\"apply_logit_bias_inplace\\\" not found in model\";\n    TVM_FFI_ICHECK(apply_penalty_func_.defined())\n        << \"Function \\\"apply_penalty_inplace\\\" not found in model\";\n    TVM_FFI_ICHECK(apply_bitmask_func_.defined())\n        << \"Function \\\"apply_bitmask_inplace\\\" not found in model\";\n\n    // If the device is CUDA/ROCm, we create a standalone copy stream in\n    // order to hide the latency of copying auxiliary data to the device.\n    if (device.device_type == DLDeviceType::kDLCUDA ||\n        device.device_type == DLDeviceType::kDLROCM) {\n      // The compute stream is the default stream.\n      compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);\n      copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);\n    }\n  }\n\n  ~LogitProcessorImpl() {\n    // Free the copy stream if defined.\n    if (copy_stream_ != nullptr) {\n      DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);\n    }\n  }\n\n  void InplaceUpdateLogits(Tensor logits,                                  //\n                           const Array<GenerationConfig>& generation_cfg,  //\n                           const Array<RequestModelState>& mstates,        //\n                           const Array<String>& request_ids,               //\n                           const std::vector<int>* cum_num_token,          //\n                           const Array<RequestModelState>* draft_mstates,  //\n                           const std::vector<std::vector<int>>* draft_token_indices) final {\n    NVTXScopedRange nvtx_scope(\"Logit inplace update\");\n    
TVM_FFI_ICHECK_EQ(logits->ndim, 2);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], vocab_size_);\n    TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32));\n    TVM_FFI_ICHECK_EQ(generation_cfg.size(), mstates.size());\n    TVM_FFI_ICHECK_LE(logits->shape[0], max_num_token_);\n    int num_total_token = logits->shape[0];\n    int num_sequence = generation_cfg.size();\n\n    TVM_FFI_ICHECK((draft_mstates == nullptr) == (draft_token_indices == nullptr));\n    if (cum_num_token != nullptr) {\n      TVM_FFI_ICHECK(draft_mstates != nullptr);\n      TVM_FFI_ICHECK_EQ(cum_num_token->size(), num_sequence + 1);\n      TVM_FFI_ICHECK_EQ(cum_num_token->back(), num_total_token);\n    } else {\n      TVM_FFI_ICHECK_EQ(num_sequence, num_total_token);\n    }\n\n    if (draft_mstates != nullptr) {\n      TVM_FFI_ICHECK_EQ(draft_mstates->size(), num_sequence);\n      TVM_FFI_ICHECK_EQ(draft_token_indices->size(), num_sequence);\n    }\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"start update logits\");\n\n    // Update 1. logit bias\n    RECORD_EVENT(trace_recorder_, request_ids, \"start apply logit bias\");\n    UpdateWithLogitBias(logits, generation_cfg, cum_num_token);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish apply logit bias\");\n\n    // Update 2. penalties\n    RECORD_EVENT(trace_recorder_, request_ids, \"start apply penalty\");\n    UpdateWithPenalty(logits, generation_cfg, mstates, cum_num_token, draft_mstates,\n                      draft_token_indices);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish apply penalty\");\n\n    // Update 3. 
Vocabulary mask.\n    // Note: The mask application must be placed as the last step in the logit processor.\n    // This is because the masked logits are set to the minimal value.\n    // Further logit subtraction may cause issues such as underflow.\n    RECORD_EVENT(trace_recorder_, request_ids, \"start apply logit mask\");\n    UpdateWithMask(logits, mstates, cum_num_token, draft_mstates, draft_token_indices);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish apply logit mask\");\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish update logits\");\n  }\n\n  Tensor ComputeProbsFromLogits(Tensor logits, const Array<GenerationConfig>& generation_cfg,\n                                const Array<String>& request_ids,\n                                const std::vector<int>* cum_num_token) final {\n    NVTXScopedRange nvtx_scope(\"Compute probs from logits\");\n    // logits: (n, v)\n    TVM_FFI_ICHECK_EQ(logits->ndim, 2);\n    TVM_FFI_ICHECK_LE(logits->shape[0], max_num_token_);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], vocab_size_);\n    TVM_FFI_ICHECK(logits.DataType() == DataType::Float(32));\n    int num_total_token = logits->shape[0];\n    int num_sequence = generation_cfg.size();\n\n    if (cum_num_token != nullptr) {\n      TVM_FFI_ICHECK_EQ(cum_num_token->size(), num_sequence + 1);\n      TVM_FFI_ICHECK_EQ(cum_num_token->back(), num_total_token);\n    } else {\n      TVM_FFI_ICHECK_EQ(num_sequence, num_total_token);\n    }\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"start softmax\");\n\n    // Construct:\n    // - temperature (max_num_token,) float32\n    float* p_temperature = static_cast<float*>(temperature_host_->data);\n\n    // - Set arrays.\n    for (int i = 0; i < num_sequence; ++i) {\n      int num_token_to_process =\n          cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i));\n      int token_offset = cum_num_token == nullptr ? 
i : cum_num_token->at(i);\n      for (int j = 0; j < num_token_to_process; ++j) {\n        p_temperature[token_offset + j] = std::max(generation_cfg[i]->temperature, 0.0);\n      }\n    }\n\n    // - View arrays.\n    Tensor temperature_host = temperature_host_.CreateView({num_total_token}, dtype_f32_);\n    Tensor temperature_device = temperature_device_.CreateView({num_total_token}, dtype_f32_);\n\n    // - Copy arrays to GPU.\n    CopyArray(/*src=*/temperature_host, /*dst=*/temperature_device, copy_stream_);\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    // - Call kernel.\n    Tensor probs = softmax_func_(logits.CreateView({num_total_token, 1, vocab_size_}, dtype_f32_),\n                                 temperature_device)\n                       .cast<Tensor>();\n    TVM_FFI_ICHECK_EQ(probs->ndim, 3);\n    TVM_FFI_ICHECK_EQ(probs->shape[0], num_total_token);\n    TVM_FFI_ICHECK_EQ(probs->shape[1], 1);\n    TVM_FFI_ICHECK_EQ(probs->shape[2], vocab_size_);\n    if (trace_recorder_.defined()) {\n      DeviceAPI::Get(device_)->StreamSync(device_, /*stream=*/nullptr);\n    }\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish softmax\");\n    return probs.CreateView({num_total_token, vocab_size_}, probs->dtype);\n  }\n\n private:\n  void UpdateWithLogitBias(Tensor logits, const Array<GenerationConfig>& generation_cfg,\n                           const std::vector<int>* cum_num_token) {\n    NVTXScopedRange nvtx_scope(\"UpdateWithLogitBias\");\n    // Construct:\n    // - pos2seq_id (max_num_token * vocab_size,) int32\n    // - token_ids (max_num_token * vocab_size,) int32\n    // - token_logit_bias (max_num_token * vocab_size,) float32\n    int* p_pos2seq_id = static_cast<int*>(pos2seq_id_host_->data);\n    int* p_token_ids = static_cast<int*>(token_ids_host_->data);\n    float* p_token_logit_bias = static_cast<float*>(token_logit_bias_host_->data);\n\n    // - Set arrays.\n    int num_token_for_bias = 0;\n    int num_bias_token = 0;\n    
for (int i = 0; i < static_cast<int>(generation_cfg.size()); ++i) {\n      int num_token_to_process =\n          cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i));\n      int token_offset = cum_num_token == nullptr ? i : cum_num_token->at(i);\n      for (int j = 0; j < num_token_to_process; ++j) {\n        if (!generation_cfg[i]->logit_bias.empty()) {\n          for (auto [token_id, bias] : generation_cfg[i]->logit_bias) {\n            p_pos2seq_id[num_bias_token] = token_offset + j;\n            p_token_ids[num_bias_token] = token_id;\n            p_token_logit_bias[num_bias_token] = bias;\n            ++num_bias_token;\n          }\n          ++num_token_for_bias;\n        }\n      }\n    }\n\n    if (num_token_for_bias == 0) {\n      return;\n    }\n\n    // - View arrays.\n    int num_token = num_bias_token;\n    Tensor pos2seq_id_host = pos2seq_id_host_.CreateView({num_token}, dtype_i32_);\n    Tensor pos2seq_id_device = pos2seq_id_device_.CreateView({num_token}, dtype_i32_);\n    Tensor token_ids_host = token_ids_host_.CreateView({num_token}, dtype_i32_);\n    Tensor token_ids_device = token_ids_device_.CreateView({num_token}, dtype_i32_);\n    Tensor token_logit_bias_host = token_logit_bias_host_.CreateView({num_token}, dtype_f32_);\n    Tensor token_logit_bias_device = token_logit_bias_device_.CreateView({num_token}, dtype_f32_);\n\n    // - Copy arrays to GPU.\n    CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device, copy_stream_);\n    CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device, copy_stream_);\n    CopyArray(/*src=*/token_logit_bias_host, /*dst=*/token_logit_bias_device, copy_stream_);\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    // - Call kernel.\n    apply_logit_bias_func_(logits, pos2seq_id_device, token_ids_device, token_logit_bias_device);\n    if (trace_recorder_.defined()) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n  }\n\n  void 
UpdateWithPenalty(Tensor logits, const Array<GenerationConfig>& generation_cfg,\n                         const Array<RequestModelState>& mstates,\n                         const std::vector<int>* cum_num_token,\n                         const Array<RequestModelState>* draft_mstates,\n                         const std::vector<std::vector<int>>* draft_token_indices) {\n    NVTXScopedRange nvtx_scope(\"UpdateWithPenalty\");\n    // Construct:\n    // - seq_ids (max_num_token,) int32\n    // - pos2seq_id (max_num_token * vocab_size,) int32\n    // - token_ids (max_num_token * vocab_size,) int32\n    // - token_cnt (max_num_token * vocab_size,) int32\n    // - penalties (max_num_token, 3) float32\n    int* p_seq_ids = static_cast<int*>(seq_ids_host_->data);\n    int* p_pos2seq_id = static_cast<int*>(pos2seq_id_host_->data);\n    int* p_token_ids = static_cast<int*>(token_ids_host_->data);\n    int* p_token_cnt = static_cast<int*>(token_cnt_host_->data);\n    float* p_penalties = static_cast<float*>(penalties_host_->data);\n\n    // - Set arrays.\n    int num_token_for_penalty = 0;\n    int num_penalty_appeared_token = 0;\n    for (int i = 0; i < static_cast<int>(generation_cfg.size()); ++i) {\n      if (generation_cfg[i]->frequency_penalty != 0.0 ||\n          generation_cfg[i]->presence_penalty != 0.0 ||\n          generation_cfg[i]->repetition_penalty != 1.0) {\n        int num_token_to_process =\n            cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i));\n        int token_offset = cum_num_token == nullptr ? 
i : cum_num_token->at(i);\n        TVM_FFI_ICHECK(num_token_to_process == 1 || mstates[i]->draft_output_tokens.empty());\n        TVM_FFI_ICHECK(draft_token_indices == nullptr ||\n                       draft_token_indices->at(i).size() == num_token_to_process);\n        for (int j = 0; j < num_token_to_process; ++j) {\n          p_seq_ids[num_token_for_penalty] = token_offset + j;\n\n          std::vector<SampleResult> draft_token_seq;\n          // Update appeared_token_ids with draft tokens\n          if (draft_token_indices != nullptr) {\n            int cur_draft_token_index = draft_token_indices->at(i)[j];\n            while (cur_draft_token_index != -1) {\n              draft_token_seq.push_back(\n                  (*draft_mstates)[i]->draft_output_tokens[cur_draft_token_index]);\n              cur_draft_token_index =\n                  (*draft_mstates)[i]->draft_token_parent_idx[cur_draft_token_index];\n            }\n          }\n          auto& appeared_token_ids = mstates[i]->appeared_token_ids;\n          for (const auto& token : draft_token_seq) {\n            appeared_token_ids[token.GetTokenId()] += 1;\n          }\n          for (auto [token_id, cnt] : appeared_token_ids) {\n            p_pos2seq_id[num_penalty_appeared_token] = num_token_for_penalty;\n            p_token_ids[num_penalty_appeared_token] = token_id;\n            p_token_cnt[num_penalty_appeared_token] = cnt;\n            ++num_penalty_appeared_token;\n          }\n          for (const auto& token : draft_token_seq) {\n            if ((--appeared_token_ids[token.GetTokenId()]) == 0) {\n              appeared_token_ids.erase(token.GetTokenId());\n            }\n          }\n          p_penalties[num_token_for_penalty * 3] = generation_cfg[i]->presence_penalty;\n          p_penalties[num_token_for_penalty * 3 + 1] = generation_cfg[i]->frequency_penalty;\n          p_penalties[num_token_for_penalty * 3 + 2] = generation_cfg[i]->repetition_penalty;\n          ++num_token_for_penalty;\n    
    }\n      }\n    }\n\n    if (num_token_for_penalty == 0) {\n      return;\n    }\n\n    // - View arrays.\n    int num_seq = num_token_for_penalty;\n    int num_token = num_penalty_appeared_token;\n    Tensor seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_);\n    Tensor seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_);\n    Tensor pos2seq_id_host = pos2seq_id_host_.CreateView({num_token}, dtype_i32_);\n    Tensor pos2seq_id_device = pos2seq_id_device_.CreateView({num_token}, dtype_i32_);\n    Tensor token_ids_host = token_ids_host_.CreateView({num_token}, dtype_i32_);\n    Tensor token_ids_device = token_ids_device_.CreateView({num_token}, dtype_i32_);\n    Tensor token_cnt_host = token_cnt_host_.CreateView({num_token}, dtype_i32_);\n    Tensor token_cnt_device = token_cnt_device_.CreateView({num_token}, dtype_i32_);\n    Tensor penalties_host = penalties_host_.CreateView({num_seq, 3}, dtype_f32_);\n    Tensor penalties_device = penalties_device_.CreateView({num_seq, 3}, dtype_f32_);\n\n    // - Copy arrays to GPU.\n    CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device, copy_stream_);\n    CopyArray(/*src=*/pos2seq_id_host, /*dst=*/pos2seq_id_device, copy_stream_);\n    CopyArray(/*src=*/token_ids_host, /*dst=*/token_ids_device, copy_stream_);\n    CopyArray(/*src=*/token_cnt_host, /*dst=*/token_cnt_device, copy_stream_);\n    CopyArray(/*src=*/penalties_host, /*dst=*/penalties_device, copy_stream_);\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    // - Call kernel.\n    apply_penalty_func_(logits, seq_ids_device, pos2seq_id_device, token_ids_device,\n                        token_cnt_device, penalties_device);\n    if (trace_recorder_.defined()) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n  }\n\n  void UpdateWithMask(Tensor logits, const Array<RequestModelState>& mstates,\n                      const std::vector<int>* cum_num_token,\n                      const 
Array<RequestModelState>* draft_mstates,\n                      const std::vector<std::vector<int>>* draft_token_indices) {\n    NVTXScopedRange nvtx_scope(\"UpdateWithMask\");\n    // Construct:\n    // - seq_ids (max_num_token,) int32\n    // - bitmask (max_num_token, ceildiv(vocab_size, 32)), int32\n    int32_t* p_seq_ids = static_cast<int32_t*>(seq_ids_host_->data);\n    uint32_t* p_bitmask = static_cast<uint32_t*>(bitmask_host_->data);\n\n    // - Set arrays.\n    int batch_size = logits->shape[0];\n    TVM_FFI_ICHECK((cum_num_token == nullptr && batch_size == mstates.size()) ||\n                   (cum_num_token != nullptr && batch_size == cum_num_token->back()));\n\n    std::memset(p_seq_ids, 0, batch_size * sizeof(int32_t));\n\n    for (int i = 0; i < static_cast<int>(mstates.size()); ++i) {\n      int token_start_offset = cum_num_token == nullptr ? i : cum_num_token->at(i);\n      int token_number =\n          cum_num_token == nullptr ? 1 : (cum_num_token->at(i + 1) - cum_num_token->at(i));\n      bool require_mask = mstates[i]->RequireNextTokenBitmask();\n      TVM_FFI_ICHECK(draft_token_indices == nullptr ||\n                     draft_token_indices->at(i).size() == token_number);\n      for (int j = 0; j < token_number; ++j) {\n        if (require_mask) {\n          std::vector<SampleResult> draft_token_seq;\n          if (draft_token_indices != nullptr) {\n            int cur_draft_token_index = draft_token_indices->at(i)[j];\n            while (cur_draft_token_index != -1) {\n              draft_token_seq.push_back(\n                  (*draft_mstates)[i]->draft_output_tokens[cur_draft_token_index]);\n              cur_draft_token_index =\n                  (*draft_mstates)[i]->draft_token_parent_idx[cur_draft_token_index];\n            }\n            for (auto it = draft_token_seq.rbegin(); it != draft_token_seq.rend(); ++it) {\n              mstates[i]->grammar_matcher.value().AcceptToken(it->GetTokenId());\n            }\n          }\n          // 
Find a slice of bitmask_host_: bitmask_host_[token_start_offset + j, :]\n          DLTensor bitmask_dltensor = *bitmask_host_.operator->();\n          int64_t bitmask_shape[] = {bitmask_size_};\n          bitmask_dltensor.data = p_bitmask + (token_start_offset + j) * bitmask_size_;\n          bitmask_dltensor.shape = bitmask_shape;\n          bitmask_dltensor.ndim = 1;\n\n          mstates[i]->GetNextTokenBitmask(&bitmask_dltensor);\n          p_seq_ids[token_start_offset + j] = 1;\n\n          if (draft_token_seq.size() > 0) {\n            mstates[i]->grammar_matcher.value().Rollback(draft_token_seq.size());\n          }\n        }\n      }\n    }\n\n    int num_token_for_mask = 0;\n    for (int i = 0; i < batch_size; ++i) {\n      if (p_seq_ids[i] == 1) {\n        p_seq_ids[num_token_for_mask] = i;\n        ++num_token_for_mask;\n      }\n    }\n\n    if (num_token_for_mask == 0) {\n      return;\n    }\n\n    // - View arrays.\n    int num_seq = num_token_for_mask;\n    Tensor seq_ids_host = seq_ids_host_.CreateView({num_seq}, dtype_i32_);\n    Tensor seq_ids_device = seq_ids_device_.CreateView({num_seq}, dtype_i32_);\n    Tensor bitmask_host = bitmask_host_.CreateView({batch_size, bitmask_size_}, dtype_i32_);\n    Tensor bitmask_device = bitmask_device_.CreateView({batch_size, bitmask_size_}, dtype_i32_);\n\n    // - Copy arrays to GPU.\n    CopyArray(/*src=*/seq_ids_host, /*dst=*/seq_ids_device, copy_stream_);\n    CopyArray(/*src=*/bitmask_host, /*dst=*/bitmask_device, copy_stream_);\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    // - Call kernel.\n    apply_bitmask_func_(logits, seq_ids_device, bitmask_device);\n    if (trace_recorder_.defined()) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n  }\n\n  // Model configurations\n  const int max_num_token_;\n  const int vocab_size_;\n  const int bitmask_size_;\n  const DLDataType dtype_i32_ = DataType::Int(32);\n  const DLDataType dtype_u32_ = DataType::UInt(32);\n  
const DLDataType dtype_f32_ = DataType::Float(32);\n  // Packed functions.\n  Device device_;\n  Function softmax_func_;\n  Function apply_logit_bias_func_;\n  Function apply_penalty_func_;\n  Function apply_bitmask_func_;\n  // Auxiliary Tensors on CPU\n  Tensor seq_ids_host_;\n  Tensor pos2seq_id_host_;\n  Tensor token_ids_host_;\n  Tensor token_cnt_host_;\n  Tensor token_logit_bias_host_;\n  Tensor penalties_host_;\n  Tensor bitmask_host_;\n  Tensor temperature_host_;\n  // Auxiliary Tensors on GPU\n  Tensor seq_ids_device_;\n  Tensor pos2seq_id_device_;\n  Tensor token_ids_device_;\n  Tensor token_cnt_device_;\n  Tensor token_logit_bias_device_;\n  Tensor penalties_device_;\n  Tensor bitmask_device_;\n  Tensor temperature_device_;\n  // Event trace recorder.\n  Optional<EventTraceRecorder> trace_recorder_;\n  // The device stream for the default computation operations.\n  TVMStreamHandle compute_stream_ = nullptr;\n  // The device stream for copying auxiliary data structure to GPU.\n  TVMStreamHandle copy_stream_ = nullptr;\n  // A small epsilon.\n  const double eps_ = 1e-5;\n};\n\nLogitProcessor::LogitProcessor(int max_num_token, int vocab_size, FunctionTable* ft,\n                               DLDevice device, Optional<EventTraceRecorder> trace_recorder) {\n  data_ = tvm::ffi::make_object<LogitProcessorImpl>(max_num_token, vocab_size, ft, device,\n                                                    std::move(trace_recorder));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/logit_processor.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/logit_processor.h\n * \\brief The header for logit processor.\n */\n\n#ifndef MLC_LLM_SERVE_LOGIT_PROCESSOR_H_\n#define MLC_LLM_SERVE_LOGIT_PROCESSOR_H_\n\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/module.h>\n\n#include \"../base.h\"\n#include \"config.h\"\n#include \"event_trace_recorder.h\"\n#include \"function_table.h\"\n#include \"request_state.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The logit processor class that updates logits with regard\n * to presence/frequency penalties, logit bias, etc.\n */\nclass LogitProcessorObj : public Object {\n public:\n  /*!\n   * \\brief In-place update a batch of logits with regard to the given\n   * generation config and request states.\n   * \\param logits The batch of raw logits, in shape (num_total_token, vocab_size),\n   * where `num_total_token` may be larger than the number of sequences\n   * indicated by `generation_cfg`, in which case some sequences may have\n   * more than one token.\n   * \\param generation_cfg The generation config of each sequence in the batch.\n   * \\param mstates The request states of each sequence in the batch.\n   * \\param request_ids The ids of each request.\n   * \\param cum_num_token The pointer to the cumulative token length of the sequences.\n   * If the pointer is nullptr, it means each sequence has only one token.\n   * \\param draft_mstates Optional. The draft request states of each sequence.\n   * \\param draft_token_indices Optional. 
The pointer to the draft token indices of each draft token\n   * in the model state (-1 indicates the token is not a draft token) when speculation is enabled.\n   * This is used to compute the sequence state with the draft tokens considered (the saved sequence\n   * state is not updated with the draft tokens).\n   */\n  virtual void InplaceUpdateLogits(\n      Tensor logits, const Array<GenerationConfig>& generation_cfg,\n      const Array<RequestModelState>& mstates, const Array<String>& request_ids,\n      const std::vector<int>* cum_num_token = nullptr,\n      const Array<RequestModelState>* draft_mstates = nullptr,\n      const std::vector<std::vector<int>>* draft_token_indices = nullptr) = 0;\n\n  /*!\n   * \\brief Compute probability distributions for the input batch of logits.\n   * \\param logits The batch of updated logits.\n   * \\param generation_cfg The generation config of each sequence in the batch.\n   * \\param request_ids The ids of each request.\n   * \\param cum_num_token The pointer to the cumulative token length of the sequences.\n   * If the pointer is nullptr, it means each sequence has only one token.\n   * \\return The batch of computed probability distributions on GPU.\n   */\n  virtual Tensor ComputeProbsFromLogits(Tensor logits,\n                                        const Array<GenerationConfig>& generation_cfg,\n                                        const Array<String>& request_ids,\n                                        const std::vector<int>* cum_num_token = nullptr) = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<LogitProcessorObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.LogitProcessor\", LogitProcessorObj, Object);\n};\n\nclass LogitProcessor : public ObjectRef 
{\n public:\n  /*!\n   * \\brief Constructor.\n   * \\param max_num_token The max number of tokens in the token processor.\n   * \\param vocab_size The model's vocabulary size.\n   * \\param ft The packed function table.\n   * \\param device The device that the model runs on.\n   * \\param trace_recorder The event trace recorder.\n   */\n  explicit LogitProcessor(int max_num_token, int vocab_size, FunctionTable* ft, DLDevice device,\n                          Optional<EventTraceRecorder> trace_recorder);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(LogitProcessor, ObjectRef, LogitProcessorObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_LOGIT_PROCESSOR_H_\n"
  },
  {
    "path": "cpp/serve/metrics.cc",
    "content": "\n/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/metrics.cc\n */\n#include \"metrics.h\"\n\n#include <tvm/runtime/logging.h>\n\n#include <sstream>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\ntvm::ffi::json::Object TimeCost::AsJSON() const {\n  tvm::ffi::json::Object config;\n  config.Set(\"count\", count);\n  if (count != 0) {\n    config.Set(\"mean\", sum / count);\n  }\n  return config;\n}\n\ntvm::ffi::json::Object SpecDecodeMetrics::AsJSON() const {\n  tvm::ffi::json::Object metrics;\n  auto f_vector_to_array = [](const std::vector<int64_t>& vec) {\n    tvm::ffi::json::Array arr;\n    for (int64_t v : vec) {\n      arr.push_back(v);\n    }\n    return tvm::ffi::json::Value(arr);\n  };\n  metrics.Set(\"draft_count\", f_vector_to_array(draft_count));\n  metrics.Set(\"accept_count\", f_vector_to_array(accept_count));\n\n  TVM_FFI_ICHECK_EQ(draft_count.size(), accept_count.size());\n  // NOTE: labels follow the Prometheus convention with full context\n  // so they can be flattened and used in the metrics reporting endpoint\n  tvm::ffi::json::Object accept_prob_metrics;\n  tvm::ffi::json::Object accept_rate_metrics;\n  tvm::ffi::json::Object accept_len_metrics;\n\n  double accept_len_value = 0;\n\n  for (size_t i = 0; i < draft_count.size(); ++i) {\n    std::ostringstream accept_prob_label;\n    accept_prob_label << \"accept_prob{step=\" << i << \"}\";\n    double accept_prob_value =\n        (static_cast<double>(accept_count[i]) / static_cast<double>(draft_count[i]));\n    accept_prob_metrics.Set(accept_prob_label.str(), accept_prob_value);\n    accept_len_value += accept_prob_value;\n\n    std::ostringstream accept_len_label;\n    accept_len_label << \"accept_len{step=\" << i << \"}\";\n    accept_len_metrics.Set(accept_len_label.str(), accept_len_value);\n\n    if (i != 0) {\n      std::ostringstream accept_rate_label;\n      accept_rate_label << \"accept_rate{step=\" << i << \"}\";\n      double accept_rate_value =\n          
accept_count[i - 1] == 0\n              ? 0.0f\n              : (static_cast<double>(accept_count[i]) / static_cast<double>(accept_count[i - 1]));\n      accept_rate_metrics.Set(accept_rate_label.str(), accept_rate_value);\n    }\n  }\n  metrics.Set(\"accept_prob\", accept_prob_metrics);\n  metrics.Set(\"accept_rate\", accept_rate_metrics);\n  metrics.Set(\"accept_len\", accept_len_metrics);\n\n  return metrics;\n}\n\ntvm::ffi::json::Object RequestMetrics::AsJSON() const {\n  tvm::ffi::json::Object metrics;\n  metrics.Set(\"prompt_tokens\", prompt_tokens);\n  metrics.Set(\"completion_tokens\", completion_tokens);\n  metrics.Set(\"prefill_tokens\", prefill_tokens);\n  metrics.Set(\"decode_tokens\", decode_tokens);\n  metrics.Set(\"jump_forward_tokens\", jump_forward_tokens);\n\n  if (prefill_tokens != 0) {\n    metrics.Set(\"prefill_tokens_per_s\", prefill_tokens / this->GetPrefillTime());\n  }\n  if (decode_tokens != 0) {\n    metrics.Set(\"decode_tokens_per_s\", decode_tokens / this->GetDecodeTime());\n  }\n  metrics.Set(\"end_to_end_latency_s\", this->GetTotalTime());\n  metrics.Set(\"ttft_s\", this->GetTTFT());\n  metrics.Set(\"inter_token_latency_s\", this->GetInterTokenLatency());\n  return metrics;\n}\n\nstd::string RequestMetrics::AsUsageJSONStr(bool include_extra) const {\n  tvm::ffi::json::Object usage;\n  usage.Set(\"prompt_tokens\", prompt_tokens);\n  usage.Set(\"completion_tokens\", completion_tokens);\n  usage.Set(\"total_tokens\", prompt_tokens + completion_tokens);\n  if (include_extra) {\n    usage.Set(\"extra\", this->AsJSON());\n  }\n  return tvm::ffi::json::Stringify(usage);\n}\n\ntvm::ffi::json::Object EngineMetrics::AsJSON() const {\n  tvm::ffi::json::Object metrics;\n  metrics.Set(\"engine_prefill_time_sum\", engine_prefill_time_sum);\n  metrics.Set(\"engine_decode_time_sum\", engine_decode_time_sum);\n  metrics.Set(\"engine_jump_forward_time_sum\", engine_jump_forward_time_sum);\n  metrics.Set(\"prompt_tokens_sum\", prompt_tokens_sum);\n  
metrics.Set(\"completion_tokens_sum\", completion_tokens_sum);\n  metrics.Set(\"prefill_tokens_sum\", prefill_tokens_sum);\n  metrics.Set(\"decode_tokens_sum\", decode_tokens_sum);\n  metrics.Set(\"jump_forward_tokens_sum\", jump_forward_tokens_sum);\n\n  if (prefill_tokens_sum != 0) {\n    metrics.Set(\"prefill_tokens_per_s\", prefill_tokens_sum / engine_prefill_time_sum);\n  }\n  if (engine_decode_time_sum != 0) {\n    metrics.Set(\"decode_tokens_per_s\", decode_tokens_sum / engine_decode_time_sum);\n  }\n\n  metrics.Set(\"last_finished_request\", last_finished_request.AsJSON());\n  if (!spec_decode.IsEmpty()) {\n    metrics.Set(\"spec_decode\", spec_decode.AsJSON());\n  }\n\n  auto f_create_time_list = [](const std::vector<TimeCost>& time_list) {\n    tvm::ffi::json::Object result;\n    for (size_t i = 1; i < time_list.size(); ++i) {\n      const TimeCost& item = time_list[i];\n      if (item.count == 0) continue;\n      std::ostringstream label_mean;\n      label_mean << \"mean{batch_size=\" << i << \"}\";\n      double mean = item.sum / item.count;\n      result.Set(label_mean.str(), mean);\n      std::ostringstream label_count;\n      label_count << \"count{batch_size=\" << i << \"}\";\n      result.Set(label_count.str(), item.count);\n    }\n    return tvm::ffi::json::Value(result);\n  };\n\n  metrics.Set(\"decode_time_by_batch_size\", f_create_time_list(decode_time_by_batch_size));\n  metrics.Set(\"draft_time_by_batch_size\", f_create_time_list(draft_time_by_batch_size));\n  metrics.Set(\"verify_time_by_batch_size\", f_create_time_list(verify_time_by_batch_size));\n\n  return metrics;\n}\n\nstd::string EngineMetrics::AsUsageJSONStr() const {\n  tvm::ffi::json::Object usage;\n  // We return engine usage as a usage field according to the OpenAI API.\n  // To comply with the API, just set prompt_tokens, completion_tokens, and total_tokens to 0.\n  // And store the information in the extra field.\n  usage.Set(\"prompt_tokens\", static_cast<int64_t>(0));\n  
usage.Set(\"completion_tokens\", static_cast<int64_t>(0));\n  usage.Set(\"total_tokens\", static_cast<int64_t>(0));\n  usage.Set(\"extra\", this->AsJSON());\n  return tvm::ffi::json::Stringify(usage);\n}\n\nvoid EngineMetrics::Reset() {\n  engine_prefill_time_sum = 0.0;\n  engine_decode_time_sum = 0.0;\n  engine_jump_forward_time_sum = 0;\n  prompt_tokens_sum = 0;\n  completion_tokens_sum = 0;\n  prefill_tokens_sum = 0;\n  decode_tokens_sum = 0;\n  jump_forward_tokens_sum = 0;\n  last_finished_request.Reset();\n  spec_decode.Reset();\n  decode_time_by_batch_size.clear();\n  draft_time_by_batch_size.clear();\n  verify_time_by_batch_size.clear();\n  decode_time_by_batch_size.resize(kEndFineGrainedTrackingBatchSize);\n  draft_time_by_batch_size.resize(kEndFineGrainedTrackingBatchSize);\n  verify_time_by_batch_size.resize(kEndFineGrainedTrackingBatchSize);\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/metrics.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/metrics.h\n * \\brief Metrics of serving engine/requests.\n */\n#ifndef MLC_LLM_SERVE_METRICS_H_\n#define MLC_LLM_SERVE_METRICS_H_\n\n#include <tvm/ffi/extra/json.h>\n#include <tvm/runtime/logging.h>\n\n#include <chrono>\n#include <string>\n#include <vector>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n// We keep all metrics containers in this header (instead of in Engine and Request State)\n// so we have a single central place to define all metrics across the engine.\n// Conceptually, these statistics are derived from engine/request behaviors.\n\n/*!\n * \\brief The class for tracking mean time cost.\n * - We maintain the number of updates (`count`) and the sum of updated values (`sum`).\n * - We support warmup. When `warmed_up` is false, the first update is discarded.\n */\nstruct TimeCost {\n  /*! \\brief the total amount of cost excluding warmup time */\n  double sum = 0.0;\n  /*! \\brief the total count of events excluding warmup */\n  int64_t count = 0;\n  /*! \\brief Whether we warmed up already, assuming one hit is enough */\n  bool warmed_up = false;\n\n  /*! \\brief Update the metric with given value. */\n  void Update(double value) {\n    if (warmed_up) {\n      sum += value;\n      count += 1;\n    } else {\n      warmed_up = true;\n    }\n  }\n\n  /*! \\brief Reset the metric. */\n  void Reset() {\n    // NOTE: no need to redo warmup\n    // assuming we are measuring the same thing\n    this->sum = 0.0;\n    this->count = 0;\n  }\n\n  /*! \\brief Dump the metric as JSON. */\n  tvm::ffi::json::Object AsJSON() const;\n};\n\n/*! \\brief Runtime metrics for speculative decoding */\nstruct SpecDecodeMetrics {\n  /*! \\brief The number of draft tokens in speculative decoding, per step */\n  std::vector<int64_t> draft_count;\n  /*! 
\\brief The number of accepted tokens in speculative decoding, per step */\n  std::vector<int64_t> accept_count;\n\n  /*!\n   * \\brief Update the metrics of speculative decoding.\n   * \\param draft_length The number of draft tokens (including the last prediction by the base\n   * model)\n   * \\param accept_length The number of accepted tokens in the speculative decoding.\n   */\n  void Update(int draft_length, int accept_length) {\n    TVM_FFI_ICHECK_GE(accept_length, 1);\n    if (accept_count.size() < draft_length) {\n      this->accept_count.resize(draft_length, 0);\n      this->draft_count.resize(draft_length, 0);\n    }\n    for (int j = 0; j < draft_length; ++j) {\n      if (j < accept_length) {\n        ++this->accept_count[j];\n      }\n      ++this->draft_count[j];\n    }\n  }\n\n  bool IsEmpty() const { return draft_count.size() == 0; }\n\n  void Reset() {\n    accept_count.clear();\n    draft_count.clear();\n  }\n  tvm::ffi::json::Object AsJSON() const;\n};\n\n/*!\n * \\brief Metrics attached to each request\n *\n * Sometimes requests can involve tree decoding (e.g., parallel n).\n * The metrics are collected across all branches of the tree.\n */\nstruct RequestMetrics {\n  /*! \\brief The number of request input tokens. */\n  int64_t prompt_tokens = 0;\n  /*! \\brief Total number of output tokens. */\n  int64_t completion_tokens = 0;\n  /*! \\brief Total number of tokens that need to be prefilled */\n  int64_t prefill_tokens = 0;\n  /*! \\brief The number of processed tokens (including tokens rolled back later) in decode. */\n  int64_t decode_tokens = 0;\n  /*! \\brief The number of tokens predicted by jump-forward decoding. */\n  int64_t jump_forward_tokens = 0;\n\n  /*! \\brief The time of adding the request to engine. */\n  std::chrono::high_resolution_clock::time_point add_time_point;\n  /*! \\brief The time of finishing prefill stage. */\n  std::chrono::high_resolution_clock::time_point prefill_end_time_point;\n  /*! \\brief The time of finishing all decode. 
*/\n  std::chrono::high_resolution_clock::time_point finish_time_point;\n\n  /*! \\brief Check whether the request metrics describe a completed request. */\n  bool IsComplete() const { return prompt_tokens != 0 && completion_tokens != 0; }\n\n  /*! \\return the prefill time in seconds */\n  double GetPrefillTime() const {\n    return static_cast<double>((prefill_end_time_point - add_time_point).count()) / 1e9;\n  }\n\n  /*! \\return the decode time in seconds */\n  double GetDecodeTime() const {\n    return static_cast<double>((finish_time_point - prefill_end_time_point).count()) / 1e9;\n  }\n\n  /*! \\return the time to first token (TTFT) in seconds */\n  double GetTTFT() const {\n    return static_cast<double>((prefill_end_time_point - add_time_point).count()) / 1e9;\n  }\n\n  /*! \\return the total time (from adding the request to finishing decode) in seconds */\n  double GetTotalTime() const {\n    return static_cast<double>((finish_time_point - add_time_point).count()) / 1e9;\n  }\n\n  /*! \\return the inter token latency (ITL) in seconds */\n  double GetInterTokenLatency() const {\n    return completion_tokens > 0 ? GetTotalTime() / completion_tokens : 0.0;\n  }\n\n  /*! \\brief Reset the metric. */\n  void Reset() {\n    this->prompt_tokens = 0;\n    this->prefill_tokens = 0;\n    this->completion_tokens = 0;\n  }\n  /*!\n   * \\brief Return the request metrics in JSON.\n   * \\return The metrics in JSON\n   */\n  tvm::ffi::json::Object AsJSON() const;\n  /*!\n   * \\brief Return OpenAI compatible usage metrics\n   * \\param include_extra Whether to include extra set of metrics\n   *\n   * \\return The usage metrics in json.\n   */\n  std::string AsUsageJSONStr(bool include_extra) const;\n};\n\n/*! \\brief Runtime metrics of engine. */\nstruct EngineMetrics {\n  /*! \\brief The total engine time on prefill, including warmup */\n  double engine_prefill_time_sum = 0;\n  /*! \\brief The total engine time on decode/draft/verify, including warmup */\n  double engine_decode_time_sum = 0;\n  /*! 
\\brief The total engine time on jump-forward prediction. */\n  double engine_jump_forward_time_sum = 0;\n  /*! \\brief The total number of request input tokens. */\n  int64_t prompt_tokens_sum = 0;\n  /*! \\brief The total number of request output tokens */\n  int64_t completion_tokens_sum = 0;\n  /*! \\brief The total number of processed tokens (excluding the prefix-cached length) in prefill */\n  int64_t prefill_tokens_sum = 0;\n  /*! \\brief The total number of processed tokens (including tokens rolled back later) in decode. */\n  int64_t decode_tokens_sum = 0;\n  /*! \\brief The total number of tokens predicted by jump-forward decoding. */\n  int64_t jump_forward_tokens_sum = 0;\n  /*! \\brief Metrics from the last finished request. */\n  RequestMetrics last_finished_request;\n  /*! \\brief Speculative decoding metrics. */\n  SpecDecodeMetrics spec_decode;\n\n  /*! \\brief The exclusive upper bound of the batch sizes tracked for fine-grained timing. */\n  static constexpr const int64_t kEndFineGrainedTrackingBatchSize = 65;\n  /*! \\brief The list of batch decode time under different batch size. */\n  std::vector<TimeCost> decode_time_by_batch_size =\n      std::vector<TimeCost>(kEndFineGrainedTrackingBatchSize);\n  /*! \\brief The list of batch draft time (a single decode step) under different batch size. */\n  std::vector<TimeCost> draft_time_by_batch_size =\n      std::vector<TimeCost>(kEndFineGrainedTrackingBatchSize);\n  /*! \\brief The list of batch verification time under different effective batch size. 
*/\n  std::vector<TimeCost> verify_time_by_batch_size =\n      std::vector<TimeCost>(kEndFineGrainedTrackingBatchSize);\n\n  // NOTE: we keep most update functions in the header\n  // so they can be inlined effectively\n  /*!\n   * \\brief Update the batch decode time for the given batch size.\n   * The time is ignored if the batch size is at least `kEndFineGrainedTrackingBatchSize`.\n   */\n  void UpdateDecodeTimeByBatchSize(int batch_size, double time) {\n    if (batch_size < kEndFineGrainedTrackingBatchSize) {\n      decode_time_by_batch_size[batch_size].Update(time);\n    }\n  }\n  /*!\n   * \\brief Update the single-step batch draft time for the given batch size.\n   * The time is ignored if the batch size is at least `kEndFineGrainedTrackingBatchSize`.\n   */\n  void UpdateDraftTimeByBatchSize(int batch_size, double time) {\n    if (batch_size < kEndFineGrainedTrackingBatchSize) {\n      draft_time_by_batch_size[batch_size].Update(time);\n    }\n  }\n  /*!\n   * \\brief Update the batch verification time for the given effective batch size.\n   * The time is ignored if the effective batch size is at least\n   * `kEndFineGrainedTrackingBatchSize`.\n   */\n  void UpdateVerifyTimeByBatchSize(int effective_batch_size, double time) {\n    if (effective_batch_size < kEndFineGrainedTrackingBatchSize) {\n      verify_time_by_batch_size[effective_batch_size].Update(time);\n    }\n  }\n\n  /*!\n   * \\brief Update global engine metrics as we finish a request\n   *  by including the information from the finished request.\n   */\n  void RequestFinishUpdate(const RequestMetrics& request_metrics) {\n    prompt_tokens_sum += request_metrics.prompt_tokens;\n    prefill_tokens_sum += request_metrics.prefill_tokens;\n    completion_tokens_sum += request_metrics.completion_tokens;\n    decode_tokens_sum += request_metrics.decode_tokens;\n    jump_forward_tokens_sum += request_metrics.jump_forward_tokens;\n    last_finished_request = request_metrics;\n  }\n  /*!\n   * \\brief Return the 
engine runtime metrics in JSON.\n   * \\return The metrics in JSON\n   */\n  tvm::ffi::json::Object AsJSON() const;\n\n  /*!\n   * \\brief Return engine metrics as a usage JSON string.\n   * \\return The resulting usage JSON string.\n   */\n  std::string AsUsageJSONStr() const;\n\n  /*! \\brief Reset all the metrics. */\n  void Reset();\n};\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_METRICS_H_\n"
  },
  {
    "path": "cpp/serve/model.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/model.cc\n * \\brief The implementation of runtime module of LLM functions (prefill/decode/etc.)\n */\n#include \"model.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/memory/memory_manager.h>\n#include <tvm/runtime/nvtx.h>\n\n#include <fstream>\n#include <unordered_set>\n\n#include \"../support/json_parser.h\"\n#include \"../support/vlm_utils.h\"\n#include \"config.h\"\n#include \"logit_processor.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/*********************** Model Implementation ***********************/\n\nTVM_FFI_STATIC_INIT_BLOCK() { ModelObj::RegisterReflection(); }\n\nclass ModelImpl;\n\nModel Model::Create(String reload_lib_path, String model_path,\n                    const tvm::ffi::json::Object& model_config, DLDevice device,\n                    const Optional<Session>& session, int num_shards, int num_stages,\n                    bool trace_enabled) {\n  return Model(tvm::ffi::make_object<ModelImpl>(reload_lib_path, model_path, model_config, device,\n                                                session, num_shards, num_stages, trace_enabled));\n}\n\nResult<tvm::ffi::json::Object> Model::LoadModelConfig(const String& model_path) {\n  using TResult = Result<tvm::ffi::json::Object>;\n  std::ifstream config_istream((model_path + \"/mlc-chat-config.json\").c_str());\n  std::ostringstream config_ostream;\n  TVM_FFI_ICHECK(config_istream);\n  config_ostream << config_istream.rdbuf();\n  std::string config_str = config_ostream.str();\n  tvm::ffi::String err;\n  auto config_json = tvm::ffi::json::Parse(config_str, &err);\n  if (!err.empty()) {\n    return TResult::Error(std::string(err));\n  }\n  auto opt = config_json.try_cast<tvm::ffi::json::Object>();\n  if (!opt.has_value()) {\n    return TResult::Error(\"Expected JSON object in model config\");\n  }\n  return TResult::Ok(*opt);\n}\n\nclass 
ModelImpl : public ModelObj {\n public:\n  /*!\n   * \\brief Constructor of ModelImpl.\n   * \\sa Model::Create\n   */\n  explicit ModelImpl(String reload_lib_path, String model_path, tvm::ffi::json::Object model_config,\n                     DLDevice device, const Optional<Session>& session, int num_shards,\n                     int num_stages, bool trace_enabled)\n      : model_(model_path), device_(device), trace_enabled_(trace_enabled) {\n    // Step 1. Process model config json string.\n    LoadModelConfigJSON(model_config);\n    // Step 2. Initialize vm, we use the packed function mechanism\n    // so there is no explicit abi dependency on these extra\n    // classes other than basic tvm runtime.\n    this->ft_.Init(reload_lib_path, device_, model_config, session, num_shards, num_stages);\n    this->num_shards_ = ft_.model_metadata_.tensor_parallel_shards;\n    this->num_stages_ = ft_.model_metadata_.pipeline_parallel_stages;\n    this->seqlen_padding_factor_ = ft_.model_metadata_.seqlen_padding_factor;\n    // Step 3. Reset\n    this->Reset();\n    // Step 4. 
Set model type\n    this->kind = GetMetadata().kv_state_kind;\n  }\n\n  /*********************** Model Computation  ***********************/\n\n  ObjectRef TokenEmbed(IntTuple token_ids, ObjectRef* dst, int offset) final {\n    NVTXScopedRange nvtx_scope(\"TokenEmbed\");\n    int num_tokens = token_ids.size();\n    if (seqlen_padding_factor_ > 1) {\n      num_tokens = (offset + num_tokens + seqlen_padding_factor_ - 1) / seqlen_padding_factor_ *\n                   seqlen_padding_factor_;\n    }\n    // Copy input token ids to device.\n    DLDataType dtype(DataType::Int(32));\n    Tensor token_ids_nd;\n    {\n      NVTXScopedRange nvtx_scope(\"Allocate token_ids at offset\");\n      token_ids_nd = token_ids_storage_->AllocTensor(offset * 4, {num_tokens}, dtype);\n      int* p_token_ids = static_cast<int*>(token_ids_nd->data) + (token_ids_nd->byte_offset) / 4;\n      for (int i = 0; i < static_cast<int>(token_ids.size()); ++i) {\n        p_token_ids[i] = token_ids[i];\n      }\n      for (int i = static_cast<int>(token_ids.size()); i < num_tokens; ++i) {\n        p_token_ids[i] = 0;\n      }\n    }\n    TVM_FFI_ICHECK_EQ(token_ids_nd->ndim, 1);\n    TVM_FFI_ICHECK_EQ(token_ids_nd->shape[0], num_tokens);\n    TVM_FFI_ICHECK_NE(prefill_chunk_size_, -1);\n    ObjectRef token_ids_dref_or_nd;\n    {\n      NVTXScopedRange nvtx_scope(\"Copy to worker 0\");\n      token_ids_dref_or_nd = ft_.CopyToWorker0(token_ids_nd, \"token_ids\", {prefill_chunk_size_});\n    }\n\n    ObjectRef embeddings = ft_.embed_func_(token_ids_dref_or_nd, params_).cast<ObjectRef>();\n    if (dst != nullptr) {\n      TVM_FFI_ICHECK(dst->defined());\n      ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset);\n      return *dst;\n    } else {\n      TVM_FFI_ICHECK_EQ(offset, 0);\n      return embeddings;\n    }\n  }\n\n  ObjectRef ImageEmbed(const Tensor& image, ObjectRef* dst, int offset) final {\n    NVTXScopedRange nvtx_scope(\"ImageEmbed\");\n    
TVM_FFI_ICHECK(ft_.image_embed_func_.defined())\n        << \"`image_embed` function is not found in the model. \";\n\n    int tmp_h = 0, tmp_w = 0;\n    CalculateResizeShape(image, this->model_type_, &tmp_h, &tmp_w);\n    Shape resize_h = {tmp_h};\n    Shape resize_w = {tmp_w};\n\n    CalculateCropShape(image, this->model_type_, &tmp_h, &tmp_w);\n    Shape crop_h = {tmp_h};\n    Shape crop_w = {tmp_w};\n\n    auto image_dref_or_nd = ft_.CopyToWorker0(image, \"image\", image.Shape());\n    ObjectRef embeddings =\n        ft_.image_embed_func_(image_dref_or_nd, resize_h, resize_w, crop_h, crop_w, params_)\n            .cast<ObjectRef>();\n    if (dst != nullptr) {\n      TVM_FFI_ICHECK(dst->defined());\n      ft_.nd_copy_embedding_to_offset_func_(embeddings, *dst, offset);\n      return *dst;\n    } else {\n      TVM_FFI_ICHECK_EQ(offset, 0);\n      return embeddings;\n    }\n  }\n\n  bool CanGetLogits() final {\n    return ft_.get_logits_func_.defined() && ft_.batch_get_logits_func_.defined();\n  }\n\n  Tensor GetLogits(const ObjectRef& hidden_states) final {\n    NVTXScopedRange nvtx_scope(\"GetLogits\");\n    TVM_FFI_ICHECK(ft_.get_logits_func_.defined())\n        << \"`get_logits` function is not found in the model.\";\n\n    ObjectRef hidden_states_dref_or_nd{nullptr};\n    if (!ft_.use_disco && hidden_states->IsInstance<DRefObj>()) {\n      hidden_states_dref_or_nd =\n          Downcast<DRef>(hidden_states)->DebugGetFromRemote(0).cast<ObjectRef>();\n    } else {\n      hidden_states_dref_or_nd = hidden_states;\n    }\n    ObjectRef ret = ft_.get_logits_func_(hidden_states_dref_or_nd, params_).cast<ObjectRef>();\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    Tensor logits{nullptr};\n    if (ft_.use_disco) {\n      logits = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<Tensor>();\n    } else {\n      logits = Downcast<Tensor>(ret);\n    }\n    // logits: (b * s, v)\n    return logits;\n  }\n\n  
Array<Tensor> GetMultiStepLogits(const ObjectRef& hidden_states) final {\n    NVTXScopedRange nvtx_scope(\"GetMultiStepLogits\");\n    TVM_FFI_ICHECK(ft_.get_logits_func_.defined())\n        << \"`get_logits` function is not found in the model.\";\n\n    ObjectRef hidden_states_dref_or_nd{nullptr};\n    ObjectRef ret = ft_.get_logits_func_(hidden_states, params_).cast<ObjectRef>();\n    Array<Tensor> logits{nullptr};\n    if (ft_.use_disco) {\n      logits = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<Array<Tensor>>();\n    } else {\n      logits = Downcast<Array<Tensor>>(ret);\n    }\n    return logits;\n  }\n\n  ObjectRef FuseEmbedHidden(const ObjectRef& embeddings, const ObjectRef& previous_hidden_states,\n                            int batch_size, int seq_len) final {\n    NVTXScopedRange nvtx_scope(\"FuseEmbedHidden\");\n\n    ObjectRef embeddings_dref_or_nd{nullptr};\n    if (!embeddings->IsInstance<DRefObj>()) {\n      // embeddings: (n, h)\n      Tensor embeddings_nd = Downcast<Tensor>(embeddings);\n      TVM_FFI_ICHECK_NE(hidden_size_, -1);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->ndim, 2);\n      TVM_FFI_ICHECK_GE(embeddings_nd->shape[0], batch_size * seq_len);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->shape[1], hidden_size_);\n      embeddings_dref_or_nd =\n          embeddings_nd.CreateView({batch_size * seq_len, hidden_size_}, embeddings_nd->dtype);\n    } else {\n      Shape embedding_shape{batch_size * seq_len, hidden_size_};\n      embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape).cast<ObjectRef>();\n    }\n\n    ObjectRef previous_hidden_states_dref_or_nd{nullptr};\n    if (!ft_.use_disco && previous_hidden_states->IsInstance<DRefObj>()) {\n      previous_hidden_states_dref_or_nd =\n          Downcast<DRef>(previous_hidden_states)->DebugGetFromRemote(0).cast<ObjectRef>();\n    } else {\n      previous_hidden_states_dref_or_nd = previous_hidden_states;\n    }\n    ObjectRef fused = 
ft_.fuse_embed_hidden_func_(embeddings_dref_or_nd,\n                                                  previous_hidden_states_dref_or_nd, params_)\n                          .cast<ObjectRef>();\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    Shape out_shape{batch_size, seq_len, hidden_size_};\n    if (ft_.use_disco) {\n      return ft_.nd_view_func_(fused, out_shape).cast<ObjectRef>();\n    } else {\n      Tensor fused_nd = Downcast<Tensor>(fused);\n      TVM_FFI_ICHECK_EQ(fused_nd->ndim, 2);\n      TVM_FFI_ICHECK_EQ(fused_nd->shape[0], batch_size * seq_len);\n      return fused_nd.CreateView(out_shape, fused_nd->dtype);\n    }\n  }\n\n  Tensor BatchPrefill(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,\n                      const std::vector<int>& lengths) final {\n    TVM_FFI_ICHECK(!seq_ids.empty());\n    TVM_FFI_ICHECK_EQ(seq_ids.size(), lengths.size());\n    int num_sequences = seq_ids.size();\n    int total_length = 0;\n\n    int* p_logit_pos = static_cast<int*>(logit_pos_arr_->data);\n    for (int i = 0; i < num_sequences; ++i) {\n      total_length += lengths[i];\n      p_logit_pos[i] = total_length - 1;\n    }\n    bool padded = total_length % seqlen_padding_factor_ != 0;\n    if (padded) {\n      total_length = (total_length + seqlen_padding_factor_ - 1) / seqlen_padding_factor_ *\n                     seqlen_padding_factor_;\n    }\n    NVTXScopedRange nvtx_scope(\"BatchPrefill num_seq=\" + std::to_string(num_sequences) +\n                               \" total_len=\" + std::to_string(total_length));\n    Tensor logit_pos_nd = logit_pos_arr_.CreateView({num_sequences}, DataType::Int(32));\n\n    TVM_FFI_ICHECK(ft_.prefill_func_.defined())\n        << \"`prefill_with_embed` function is not found in the model. 
Please make sure the model is \"\n           \"compiled with flag `--sep-embed` and `--enable-batching`\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(lengths.begin(), lengths.end());\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);\n\n    ObjectRef embeddings_dref_or_nd;\n    if (!embeddings->IsInstance<DRefObj>()) {\n      // embeddings: (1, n, h)\n      Tensor embeddings_nd = Downcast<Tensor>(embeddings);\n      TVM_FFI_ICHECK_NE(hidden_size_, -1);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->ndim, 2);\n      TVM_FFI_ICHECK_GE(embeddings_nd->shape[0], total_length);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->shape[1], hidden_size_);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id);\n      embeddings_dref_or_nd =\n          embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype);\n    } else {\n      Shape embedding_shape{1, total_length, hidden_size_};\n      embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape).cast<ObjectRef>();\n    }\n    TVM_FFI_ICHECK_NE(max_num_sequence_, -1);\n    ObjectRef logit_pos_dref_or_nd =\n        ft_.CopyToWorker0(logit_pos_nd, \"logit_pos\", {max_num_sequence_});\n\n    Function single_batch_prefill_func = ft_.single_batch_prefill_func_;\n    Function prefill_func = ft_.prefill_func_;\n    if (ft_.single_batch_extend_func_.defined()) {\n      TVM_FFI_ICHECK(ft_.extend_func_.defined())\n          << \"`batch_extend` function is not found in the model.\";\n      bool has_existing_sequence = false;\n      for (int64_t seq_id : seq_ids) {\n        if 
(prefilled_seq_ids_.count(seq_id)) {\n          has_existing_sequence = true;\n          break;\n        }\n      }\n      if (has_existing_sequence) {\n        single_batch_prefill_func = ft_.single_batch_extend_func_;\n        prefill_func = ft_.extend_func_;\n      }\n\n      for (int64_t seq_id : seq_ids) {\n        prefilled_seq_ids_.insert(seq_id);\n      }\n    }\n\n    // args: embeddings, logit_pos, kv_cache, params\n    ObjectRef ret;\n    if (seq_ids.size() == 1 && !padded) {\n      ret = single_batch_prefill_func(embeddings_dref_or_nd, kv_cache_, params_).cast<ObjectRef>();\n    } else {\n      ret = prefill_func(embeddings_dref_or_nd, logit_pos_dref_or_nd, kv_cache_, params_)\n                .cast<ObjectRef>();\n    }\n    Tensor logits;\n    if (ft_.use_disco) {\n      ret = ft_.tuple_getitem_func_(ret, 0).cast<ObjectRef>();\n      if (num_stages_ > 1) {\n        // Send the result from the last worker group to worker 0.\n        Shape shape{1, num_sequences, vocab_size_};\n        DataType dtype = DataType::Float(32);\n        ret = ft_.last_group_send_to_worker_0_(ret, disco_logits_arr_, shape, dtype)\n                  .cast<ObjectRef>();\n      }\n      logits = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<Tensor>();\n    } else {\n      logits = Downcast<Array<Tensor>>(ret)[0];\n    }\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n\n    // logits: (1, num_sequences, v)\n    TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n    TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], num_sequences);\n    return logits;\n  }\n\n  ObjectRef BatchPrefillToLastHidden(const ObjectRef& embedding_or_hidden_states,\n                                     const std::vector<int64_t>& seq_ids,\n                                     const std::vector<int>& lengths) final {\n    NVTXScopedRange nvtx_scope(\"BatchPrefillToLastHidden\");\n    
TVM_FFI_ICHECK(!seq_ids.empty());\n    TVM_FFI_ICHECK_EQ(seq_ids.size(), lengths.size());\n    int num_sequences = seq_ids.size();\n    int total_length = 0;\n\n    for (int i = 0; i < num_sequences; ++i) {\n      total_length += lengths[i];\n    }\n\n    ObjectRef embedding_or_hidden_states_dref_or_nd{nullptr};\n    Shape hidden_states_shape{1, total_length, hidden_size_};\n    if (!ft_.use_disco) {\n      Tensor embedding_or_hidden_states_nd = Downcast<Tensor>(embedding_or_hidden_states);\n      embedding_or_hidden_states_dref_or_nd = embedding_or_hidden_states_nd.CreateView(\n          hidden_states_shape, embedding_or_hidden_states_nd->dtype);\n    } else {\n      embedding_or_hidden_states_dref_or_nd =\n          ft_.nd_view_func_(embedding_or_hidden_states, hidden_states_shape).cast<ObjectRef>();\n    }\n\n    TVM_FFI_ICHECK(ft_.prefill_to_last_hidden_func_.defined())\n        << \"`prefill_to_last_hidden_states` function is not found in the model.\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(lengths.begin(), lengths.end());\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);\n\n    // args: embedding_or_hidden_states, kv_cache, params\n    ObjectRef result{nullptr};\n    if (seq_ids.size() == 1) {\n      TVM_FFI_ICHECK(ft_.single_batch_prefill_to_last_hidden_func_.defined())\n          << \"`single_batch_prefill_to_last_hidden_states` function is not found in the model.\";\n      result = ft_.single_batch_prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd,\n                                                             kv_cache_, params_)\n                   .cast<ObjectRef>();\n    } else {\n      result = 
ft_.prefill_to_last_hidden_func_(embedding_or_hidden_states_dref_or_nd, kv_cache_,\n                                                params_)\n                   .cast<ObjectRef>();\n    }\n    ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0).cast<ObjectRef>();\n\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n\n    Shape out_shape{total_length, hidden_size_};\n    if (ft_.use_disco) {\n      return ft_.nd_view_func_(hidden_states, out_shape).cast<ObjectRef>();\n    } else {\n      Tensor hidden_states_nd = Downcast<Tensor>(hidden_states);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->ndim, 3);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[0], 1);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[1], total_length);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_);\n      return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype);\n    }\n  }\n\n  Tensor BatchDecode(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids) final {\n    NVTXScopedRange nvtx_scope(\"BatchDecode num_seqs=\" + std::to_string(seq_ids.size()));\n    int num_sequence = seq_ids.size();\n\n    TVM_FFI_ICHECK(ft_.decode_func_.defined())\n        << \"`decode_with_embed` function is not found in the model. 
Please make sure the model is \"\n           \"compiled with flag `--sep-embed` and `--enable-batching`\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Reserve in KV cache for the lengths of the input.\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(std::vector<int64_t>(/*n=*/seq_ids.size(), /*v=*/1));\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);\n\n    ObjectRef embeddings_dref_or_nd;\n    if (!embeddings->IsInstance<DRefObj>()) {\n      // embeddings: (1, b, h)\n      Tensor embeddings_nd = Downcast<Tensor>(embeddings);\n      TVM_FFI_ICHECK_NE(hidden_size_, -1);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->ndim, 2);\n      TVM_FFI_ICHECK_GE(embeddings_nd->shape[0], num_sequence);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->shape[1], hidden_size_);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id);\n      embeddings_dref_or_nd =\n          embeddings_nd.CreateView({num_sequence, 1, hidden_size_}, embeddings_nd->dtype);\n    } else {\n      Shape embedding_shape{num_sequence, 1, hidden_size_};\n      embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape).cast<ObjectRef>();\n    }\n\n    // args: embeddings, kv_cache, params\n    ObjectRef ret;\n    if (seq_ids.size() == 1) {\n      ret = ft_.single_batch_decode_func_(embeddings_dref_or_nd, kv_cache_, params_)\n                .cast<ObjectRef>();\n    } else {\n      ret = ft_.decode_func_(embeddings_dref_or_nd, kv_cache_, params_).cast<ObjectRef>();\n    }\n    Tensor logits;\n    if (ft_.use_disco) {\n      ret = ft_.tuple_getitem_func_(ret, 0).cast<ObjectRef>();\n      if (num_stages_ > 1) {\n        // Send the 
result from the last worker group to worker 0.\n        Shape shape{num_sequence, 1, vocab_size_};\n        DataType dtype = DataType::Float(32);\n        ret = ft_.last_group_send_to_worker_0_(ret, disco_logits_arr_, shape, dtype)\n                  .cast<ObjectRef>();\n      }\n      logits = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<Tensor>();\n    } else {\n      logits = Downcast<Array<Tensor>>(ret)[0];\n    }\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n\n    // logits: (b, 1, v)\n    TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n    TVM_FFI_ICHECK_EQ(logits->shape[0], num_sequence);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], 1);\n    return logits;\n  }\n\n  Tensor BatchTreeDecode(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,\n                         const std::vector<int>& lengths,\n                         const std::vector<int64_t>& token_tree_parent_ptr) {\n    // This is similar to BatchDecode, except that it takes 'length', so that each sequence can have\n    // multiple leaf nodes for decoding.\n    NVTXScopedRange nvtx_scope(\"BatchTreeDecode num_seqs=\" + std::to_string(seq_ids.size()));\n    int num_sequence = seq_ids.size();\n    int total_length = 0;\n    for (int i = 0; i < num_sequence; ++i) {\n      total_length += lengths[i];\n    }\n    TVM_FFI_ICHECK_EQ(total_length, token_tree_parent_ptr.size());\n\n    TVM_FFI_ICHECK(ft_.decode_func_.defined())\n        << \"`tree_decode_with_embed` function is not found in the model. 
Please make sure the \"\n           \"model \"\n           \"is compiled with flag `--sep-embed` and `--enable-batching`\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Reserve in KV cache for the lengths of the input.\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(lengths.begin(), lengths.end());\n    IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,\n                                     token_tree_parent_ptr_tuple);\n\n    ObjectRef embeddings_dref_or_nd;\n    if (!embeddings->IsInstance<DRefObj>()) {\n      // embeddings: (1, n, h)\n      Tensor embeddings_nd = Downcast<Tensor>(embeddings);\n      TVM_FFI_ICHECK_NE(hidden_size_, -1);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->ndim, 2);\n      TVM_FFI_ICHECK_GE(embeddings_nd->shape[0], total_length);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->shape[1], hidden_size_);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id);\n      embeddings_dref_or_nd =\n          embeddings_nd.CreateView({total_length, 1, hidden_size_}, embeddings_nd->dtype);\n    } else {\n      Shape embedding_shape{total_length, 1, hidden_size_};\n      embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape).cast<ObjectRef>();\n    }\n\n    // same as BatchDecode\n    ObjectRef ret;\n    if (0 && seq_ids.size() == 1) {\n      ret = ft_.single_batch_decode_func_(embeddings_dref_or_nd, kv_cache_, params_)\n                .cast<ObjectRef>();\n    } else {\n      ret = ft_.decode_func_(embeddings_dref_or_nd, kv_cache_, params_).cast<ObjectRef>();\n    }\n    Tensor logits;\n    if 
(ft_.use_disco) {\n      Array<ObjectRef> result = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<Array<ObjectRef>>();\n      logits = Downcast<Tensor>(result[0]);\n    } else {\n      logits = Downcast<Array<Tensor>>(ret)[0];\n    }\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n\n    // logits: (total_length, 1, v)\n    TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n    TVM_FFI_ICHECK_EQ(logits->shape[0], total_length);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], 1);\n    return logits;\n  }\n\n  ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states_dref_or_nd,\n                                    const std::vector<int64_t>& seq_ids) final {\n    NVTXScopedRange nvtx_scope(\"BatchDecodeToLastHidden num_seqs=\" +\n                               std::to_string(seq_ids.size()));\n    int num_sequence = seq_ids.size();\n\n    TVM_FFI_ICHECK(ft_.decode_to_last_hidden_func_.defined())\n        << \"`batch_decode_to_last_hidden_states` function is not found in the model.\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Reserve in KV cache for the lengths of the input.\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(std::vector<int64_t>(/*n=*/seq_ids.size(), /*v=*/1));\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple);\n\n    // args: hidden_states, kv_cache, params\n    ObjectRef result{nullptr};\n    if (seq_ids.size() == 1) {\n      TVM_FFI_ICHECK(ft_.single_batch_decode_to_last_hidden_func_.defined())\n          << \"`decode_to_last_hidden_states` function is not found in the model.\";\n      result =\n          ft_.single_batch_decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_)\n     
         .cast<ObjectRef>();\n    } else {\n      result = ft_.decode_to_last_hidden_func_(hidden_states_dref_or_nd, kv_cache_, params_)\n                   .cast<ObjectRef>();\n    }\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n    ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0).cast<ObjectRef>();\n\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n\n    // hidden_states: (b, 1, h) to (b, h)\n    Shape out_shape{num_sequence, hidden_size_};\n    if (ft_.use_disco) {\n      return ft_.nd_view_func_(hidden_states, out_shape).cast<ObjectRef>();\n    } else {\n      Tensor hidden_states_nd = Downcast<Tensor>(hidden_states);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->ndim, 3);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[0], num_sequence);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[1], 1);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_);\n      return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype);\n    }\n  }\n\n  Tensor BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,\n                     const std::vector<int>& lengths,\n                     const std::vector<int64_t>& token_tree_parent_ptr) final {\n    TVM_FFI_ICHECK(!seq_ids.empty());\n    TVM_FFI_ICHECK_EQ(seq_ids.size(), lengths.size());\n    int num_sequences = seq_ids.size();\n    int total_length = 0;\n    for (int i = 0; i < num_sequences; ++i) {\n      total_length += lengths[i];\n    }\n    TVM_FFI_ICHECK_EQ(total_length, token_tree_parent_ptr.size());\n\n    NVTXScopedRange nvtx_scope(\"BatchVerify num_tokens=\" + std::to_string(total_length));\n\n    TVM_FFI_ICHECK(ft_.verify_func_.defined())\n        << \"`verify_with_embed` function is not found in the model. 
Please make sure the model is \"\n           \"compiled with flag `--sep-embed` and `--enable-batching`\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(lengths.begin(), lengths.end());\n    IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,\n                                     token_tree_parent_ptr_tuple);\n\n    ObjectRef embeddings_dref_or_nd;\n    if (!embeddings->IsInstance<DRefObj>()) {\n      // embeddings: (1, n, h)\n      Tensor embeddings_nd = Downcast<Tensor>(embeddings);\n      TVM_FFI_ICHECK_NE(hidden_size_, -1);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->ndim, 2);\n      TVM_FFI_ICHECK_GE(embeddings_nd->shape[0], total_length);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->shape[1], hidden_size_);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id);\n      embeddings_dref_or_nd =\n          embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype);\n    } else {\n      Shape embedding_shape{1, total_length, hidden_size_};\n      embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape).cast<ObjectRef>();\n    }\n    // args: embeddings, kv_cache, params\n    ObjectRef ret = ft_.verify_func_(embeddings_dref_or_nd, kv_cache_, params_).cast<ObjectRef>();\n    Tensor logits;\n    if (ft_.use_disco) {\n      ret = ft_.tuple_getitem_func_(ret, 0).cast<ObjectRef>();\n      if (num_stages_ > 1) {\n        // Send the result from the last worker group to worker 0.\n        Shape shape{1, total_length, vocab_size_};\n        DataType dtype = 
DataType::Float(32);\n        ret = ft_.last_group_send_to_worker_0_(ret, disco_logits_arr_, shape, dtype)\n                  .cast<ObjectRef>();\n      }\n      logits = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<Tensor>();\n    } else {\n      logits = Downcast<Array<Tensor>>(ret)[0];\n    }\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n\n    // logits: (1, total_length, v)\n    TVM_FFI_ICHECK_EQ(logits->ndim, 3);\n    TVM_FFI_ICHECK_EQ(logits->shape[0], 1);\n    TVM_FFI_ICHECK_EQ(logits->shape[1], total_length);\n    return logits;\n  }\n\n  ObjectRef BatchVerifyToLastHidden(const ObjectRef& embeddings,\n                                    const std::vector<int64_t>& seq_ids,\n                                    const std::vector<int>& lengths,\n                                    const std::vector<int64_t>& token_tree_parent_ptr) final {\n    TVM_FFI_ICHECK(!seq_ids.empty());\n    TVM_FFI_ICHECK_EQ(seq_ids.size(), lengths.size());\n    int num_sequences = seq_ids.size();\n    int total_length = 0;\n    for (int i = 0; i < num_sequences; ++i) {\n      total_length += lengths[i];\n    }\n    TVM_FFI_ICHECK_EQ(total_length, token_tree_parent_ptr.size());\n    NVTXScopedRange nvtx_scope(\"BatchVerifyToLastHidden num_tokens=\" +\n                               std::to_string(total_length));\n\n    TVM_FFI_ICHECK(ft_.verify_to_last_hidden_func_.defined())\n        << \"`batch_verify_to_last_hidden_states` function is not found in the model.\";\n    TVM_FFI_ICHECK(ft_.kv_cache_begin_forward_func_.defined());\n    TVM_FFI_ICHECK(ft_.kv_cache_end_forward_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    ObjectRef embeddings_dref_or_nd;\n    if (!embeddings->IsInstance<DRefObj>()) {\n      // embeddings: (1, n, h)\n      Tensor embeddings_nd = Downcast<Tensor>(embeddings);\n      TVM_FFI_ICHECK_NE(hidden_size_, 
-1);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->ndim, 2);\n      TVM_FFI_ICHECK_GE(embeddings_nd->shape[0], total_length);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->shape[1], hidden_size_);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_type, device_.device_type);\n      TVM_FFI_ICHECK_EQ(embeddings_nd->device.device_id, device_.device_id);\n      embeddings_dref_or_nd =\n          embeddings_nd.CreateView({1, total_length, hidden_size_}, embeddings_nd->dtype);\n    } else {\n      Shape embedding_shape{1, total_length, hidden_size_};\n      embeddings_dref_or_nd = ft_.nd_view_func_(embeddings, embedding_shape).cast<ObjectRef>();\n    }\n    // Begin forward with the sequence ids and new lengths.\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple lengths_tuple(lengths.begin(), lengths.end());\n    IntTuple token_tree_parent_ptr_tuple(token_tree_parent_ptr);\n    ft_.kv_cache_begin_forward_func_(kv_cache_, seq_ids_tuple, lengths_tuple,\n                                     token_tree_parent_ptr_tuple);\n\n    // args: embeddings, kv_cache, params\n    ObjectRef result = ft_.verify_to_last_hidden_func_(embeddings_dref_or_nd, kv_cache_, params_)\n                           .cast<ObjectRef>();\n    ft_.kv_cache_end_forward_func_(kv_cache_);\n    ObjectRef hidden_states = ft_.tuple_getitem_func_(result, 0).cast<ObjectRef>();\n    if (trace_enabled_) {\n      DeviceAPI::Get(device_)->StreamSync(device_, nullptr);\n    }\n\n    Shape out_shape{total_length, hidden_size_};\n    if (!ft_.use_disco) {\n      Tensor hidden_states_nd = Downcast<Tensor>(hidden_states);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->ndim, 3);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[0], 1);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[1], total_length);\n      TVM_FFI_ICHECK_EQ(hidden_states_nd->shape[2], hidden_size_);\n      return hidden_states_nd.CreateView(out_shape, hidden_states_nd->dtype);\n    } else {\n      return ft_.nd_view_func_(hidden_states, 
out_shape).cast<ObjectRef>();\n    }\n  }\n\n  /*********************** KV Cache Management  ***********************/\n\n  void CreateKVCache(int page_size, int max_num_sequence, int64_t max_total_sequence_length,\n                     int64_t prefill_chunk_size, int max_history_size) final {\n    KVStateKind kv_state_kind = GetMetadata().kv_state_kind;\n    if (kv_state_kind == KVStateKind::kKVCache) {\n      IntTuple max_num_sequence_tuple{max_num_sequence};\n      IntTuple max_total_sequence_length_tuple{max_total_sequence_length};\n      IntTuple prefill_chunk_size_tuple{prefill_chunk_size};\n      IntTuple page_size_tuple{page_size};\n      IntTuple support_sliding_window{sliding_window_size_ != -1};\n      kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_total_sequence_length_tuple,\n                                            prefill_chunk_size_tuple, page_size_tuple,\n                                            support_sliding_window)\n                      .cast<ObjectRef>();\n      local_kv_cache_ = ft_.use_disco\n                            ? Downcast<DRef>(kv_cache_)->DebugGetFromRemote(0).cast<ObjectRef>()\n                            : kv_cache_;\n    } else if (kv_state_kind == KVStateKind::kRNNState) {\n      IntTuple max_num_sequence_tuple{max_num_sequence};\n      IntTuple max_history_size_tuple = {std::max(max_history_size, 1)};\n      kv_cache_ = ft_.create_kv_cache_func_(max_num_sequence_tuple, max_history_size_tuple)\n                      .cast<ObjectRef>();\n      local_kv_cache_ = ft_.use_disco\n                            ? 
Downcast<DRef>(kv_cache_)->DebugGetFromRemote(0).cast<ObjectRef>()\n                            : kv_cache_;\n    } else if (kv_state_kind == KVStateKind::kNone) {\n      // Do nothing\n    } else {\n      LOG(FATAL) << \"Unknown kv_state_kind: \" << static_cast<int>(kv_state_kind);\n    }\n  }\n\n  void AddNewSequence(int64_t seq_id) final {\n    if (ft_.model_metadata_.kv_state_kind == KVStateKind::kNone) {\n      return;\n    }\n    ft_.kv_cache_add_sequence_func_(kv_cache_, seq_id);\n  }\n\n  void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos) final {\n    if (ft_.model_metadata_.kv_state_kind == KVStateKind::kNone) {\n      return;\n    }\n    ft_.kv_cache_fork_sequence_func_(kv_cache_, parent_seq_id, child_seq_id, fork_pos);\n    prefilled_seq_ids_.insert(child_seq_id);\n  }\n\n  void RemoveSequence(int64_t seq_id) final {\n    if (this->kind == KVStateKind::kNone) {\n      return;\n    }\n    prefilled_seq_ids_.erase(seq_id);\n    ft_.kv_cache_remove_sequence_func_(kv_cache_, seq_id);\n  }\n\n  void PopNFromKVCache(int64_t seq_id, int num_tokens) final {\n    if (this->kind == KVStateKind::kNone) {\n      return;\n    }\n    ft_.kv_cache_popn_func_(kv_cache_, seq_id, num_tokens);\n  }\n\n  void CommitAcceptedTokenTreeNodesToKVCache(\n      const std::vector<int64_t>& seq_ids,\n      const std::vector<int64_t>& accepted_leaf_indices) final {\n    IntTuple seq_ids_tuple(seq_ids);\n    IntTuple accepted_leaf_indices_tuple(accepted_leaf_indices);\n    ft_.kv_cache_commit_accepted_token_tree_nodes_func_(kv_cache_, seq_ids_tuple,\n                                                        accepted_leaf_indices_tuple);\n  }\n\n  void EnableSlidingWindowForSeq(int64_t seq_id) final {\n    if (this->kind == KVStateKind::kNone) {\n      return;\n    }\n    if (sliding_window_size_ != -1) {\n      ft_.kv_cache_enable_sliding_window_for_seq_(kv_cache_, seq_id, sliding_window_size_,\n                                                  
attention_sink_size_);\n    }\n  }\n\n  IntTuple DisaggPrepareKVRecv(int64_t seq_id, int length) final {\n    NVTXScopedRange nvtx_scope(\"DisaggPrepareKVRecv length=\" + std::to_string(length));\n\n    TVM_FFI_ICHECK(ft_.kv_cache_disagg_prepare_recv_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Run KV receive preparation.\n    ObjectRef ret;\n    ret = ft_.kv_cache_disagg_prepare_recv_func_(kv_cache_, seq_id, length).cast<ObjectRef>();\n    IntTuple compressed_kv_append_metadata;\n    if (ft_.use_disco) {\n      compressed_kv_append_metadata = Downcast<DRef>(ret)->DebugGetFromRemote(0).cast<IntTuple>();\n    } else {\n      compressed_kv_append_metadata = Downcast<IntTuple>(ret);\n    }\n\n    return compressed_kv_append_metadata;\n  }\n\n  void DisaggMarkKVSend(int64_t seq_id, int begin_pos, IntTuple compressed_kv_append_metadata,\n                        int dst_group_offset) final {\n    NVTXScopedRange nvtx_scope(\"DisaggMarkKVSend seq_id=\" + std::to_string(seq_id) +\n                               \" begin_pos=\" + std::to_string(begin_pos));\n\n    TVM_FFI_ICHECK(ft_.kv_cache_disagg_mark_send_func_.defined());\n    TVM_FFI_ICHECK(kv_cache_.defined()) << \"KV cache has not been initialized.\";\n\n    // Run KV send preparation.\n    ft_.kv_cache_disagg_mark_send_func_(kv_cache_, seq_id, begin_pos, compressed_kv_append_metadata,\n                                        dst_group_offset);\n  }\n\n  /************** Raw Info Query **************/\n\n  ModelMetadata GetMetadata() const final { return ft_.model_metadata_; }\n\n  int GetNumAvailablePages() const final {\n    if (this->kind == KVStateKind::kRNNState || this->kind == KVStateKind::kNone) {\n      // RNNState does not introduce new page at runtime\n      return std::numeric_limits<int>::max();\n    } else {\n      return ft_.kv_cache_get_num_available_pages_func_(local_kv_cache_).cast<int>();\n    }\n  }\n\n  int 
GetCurrentTotalSequenceLength() const final {\n    if (this->kind == KVStateKind::kRNNState || this->kind == KVStateKind::kNone) {\n      // RNNState does not have a total sequence length limit\n      return 0;\n    } else {\n      return ft_.kv_cache_get_total_sequence_length_func_(local_kv_cache_).cast<int>();\n    }\n  }\n\n  /*********************** Utilities  ***********************/\n\n  void LoadParams() final { this->params_ = ft_.LoadParams(model_, device_); }\n\n  void SetMaxNumSequence(int max_num_sequence) final {\n    this->max_num_sequence_ = max_num_sequence;\n    this->logit_pos_arr_ =\n        Tensor::Empty({max_num_sequence}, DataType::Int(32), Device{DLDeviceType::kDLCPU, 0});\n  }\n\n  void SetPrefillChunkSize(int prefill_chunk_size) final {\n    this->prefill_chunk_size_ = prefill_chunk_size;\n    Device preferred_host_device = GetPreferredHostDevice(device_);\n    memory::Allocator* allocator = memory::MemoryManager::GetOrCreateAllocator(\n        preferred_host_device, memory::AllocatorType::kNaive);\n    TVM_FFI_ICHECK_NOTNULL(allocator);\n    token_ids_storage_ = memory::Storage(\n        allocator->Alloc(preferred_host_device, {prefill_chunk_size_}, DataType::Int(32)),\n        allocator);\n    if (this->num_stages_ > 1) {\n      // Create a remote Tensor for logits when pipeline parallelism is enabled.\n      disco_logits_arr_ =\n          ft_.Empty({prefill_chunk_size_, vocab_size_}, DataType::Float(32), device_,\n                    /*worker0_only=*/true);\n    }\n  }\n\n  LogitProcessor CreateLogitProcessor(int max_num_token,\n                                      Optional<EventTraceRecorder> trace_recorder) final {\n    return LogitProcessor(max_num_token, vocab_size_, &this->ft_, device_,\n                          std::move(trace_recorder));\n  }\n\n  Sampler CreateSampler(int max_num_sample, int num_models,\n                        Optional<EventTraceRecorder> trace_recorder) final {\n    if (Sampler::SupportGPUSampler(device_)) 
{\n      return Sampler::CreateGPUSampler(max_num_sample, vocab_size_, &this->ft_, device_,\n                                       std::move(trace_recorder));\n    } else {\n      return Sampler::CreateCPUSampler(std::move(trace_recorder));\n    }\n  }\n\n  int EstimateHostCPURequirement() const final {\n    TVM_FFI_ICHECK_NE(num_shards_, -1) << \"The model has not been initialized\";\n    return num_shards_ > 1 ? num_shards_ : 0;\n  }\n\n  int GetSlidingWindowSize() const final { return sliding_window_size_; }\n\n  int GetAttentionSinkSize() const final { return attention_sink_size_; }\n\n  ObjectRef AllocEmbeddingTensor() final {\n    if (!ft_.alloc_embedding_tensor_func_.defined()) {\n      return ObjectRef{nullptr};\n    }\n    // Allocate the embedding tensor.\n    ObjectRef embedding = ft_.alloc_embedding_tensor_func_().cast<ObjectRef>();\n    // Get the shape of the embedding tensor for hidden size.\n    Shape embedding_shape;\n    if (ft_.use_disco) {\n      TVM_FFI_ICHECK(embedding->IsInstance<DRefObj>());\n      ObjectRef shape_ref = ft_.nd_get_shape_func_(embedding).cast<ObjectRef>();\n      embedding_shape = Downcast<DRef>(shape_ref)->DebugGetFromRemote(0).cast<Shape>();\n    } else {\n      Tensor embedding_nd = Downcast<Tensor>(embedding);\n      embedding_shape = embedding_nd.Shape();\n    }\n    TVM_FFI_ICHECK_NE(prefill_chunk_size_, -1);\n    TVM_FFI_ICHECK_EQ(embedding_shape.size(), 2);\n    TVM_FFI_ICHECK_GE(embedding_shape[0], prefill_chunk_size_);\n    this->hidden_size_ = embedding_shape[1];\n    return embedding;\n  }\n\n  ObjectRef AllocHiddenStatesTensor() final {\n    if (!ft_.alloc_embedding_tensor_func_.defined()) {\n      return ObjectRef{nullptr};\n    }\n    // Allocate the hidden_states tensor.\n    // Use the same function as embeddings.\n    ObjectRef hidden_states = ft_.alloc_embedding_tensor_func_().cast<ObjectRef>();\n    Tensor hidden_states_nd{nullptr};\n    // Get the shape of the hidden_states tensor for hidden size.\n    
if (ft_.use_disco) {\n      TVM_FFI_ICHECK(hidden_states->IsInstance<DRefObj>());\n      hidden_states_nd = Downcast<DRef>(hidden_states)->DebugGetFromRemote(0).cast<Tensor>();\n    } else {\n      hidden_states_nd = Downcast<Tensor>(hidden_states);\n    }\n    Shape hidden_states_shape = hidden_states_nd.Shape();\n    TVM_FFI_ICHECK_NE(prefill_chunk_size_, -1);\n    TVM_FFI_ICHECK_EQ(hidden_states_shape.size(), 2);\n    TVM_FFI_ICHECK_GE(hidden_states_shape[0], prefill_chunk_size_);\n    this->hidden_size_ = hidden_states_shape[1];\n    this->hidden_states_dtype_ = hidden_states_nd->dtype;\n    return hidden_states;\n  }\n\n  void Reset() final {\n    // Reset the KV cache.\n    if (kv_cache_.defined()) {\n      ft_.reset_kv_cache_func_(kv_cache_);\n    }\n  }\n\n  /********************** Utilities for speculative decoding **********************/\n\n  DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_tokens) {\n    return DraftTokenWorkspaceManager(max_num_tokens, vocab_size_, hidden_size_,\n                                      hidden_states_dtype_, device_, ft_);\n  }\n\n  ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector<int>& indices,\n                               ObjectRef* dst) final {\n    ObjectRef dst_view{nullptr};\n    Shape out_shape{static_cast<int64_t>(indices.size()), hidden_size_};\n    if ((*dst)->IsInstance<DRefObj>()) {\n      dst_view = ft_.nd_view_func_(*dst, out_shape).cast<ObjectRef>();\n    } else {\n      Tensor dst_nd = Downcast<Tensor>(*dst);\n      dst_view = dst_nd.CreateView(out_shape, hidden_states_dtype_);\n    }\n    Tensor indices_nd =\n        logit_pos_arr_.CreateView({static_cast<int64_t>(indices.size())}, DataType::Int(32));\n    indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int));\n    TVM_FFI_ICHECK_NE(max_num_sequence_, -1);\n    ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, \"logit_pos\", {max_num_sequence_});\n    
ft_.gather_hidden_states_func_(input, indices_device, dst_view);\n    return dst_view;\n  }\n\n  void ScatterHiddenStates(const ObjectRef& input, const std::vector<int>& indices,\n                           ObjectRef* dst) final {\n    Tensor indices_nd =\n        logit_pos_arr_.CreateView({static_cast<int64_t>(indices.size())}, DataType::Int(32));\n    indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int));\n    TVM_FFI_ICHECK_NE(max_num_sequence_, -1);\n    ObjectRef indices_device = ft_.CopyToWorker0(indices_nd, \"logit_pos\", {max_num_sequence_});\n    ft_.scatter_hidden_states_func_(input, indices_device, *dst);\n  }\n\n  Tensor GatherDraftProbs(const Tensor& input, const std::vector<int>& indices, Tensor* dst) final {\n    Tensor dst_view =\n        dst->CreateView({static_cast<int64_t>(indices.size()), vocab_size_}, DataType::Float(32));\n    Tensor indices_nd =\n        logit_pos_arr_.CreateView({static_cast<int64_t>(indices.size())}, DataType::Int(32));\n    indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int));\n    TVM_FFI_ICHECK_NE(max_num_sequence_, -1);\n    ObjectRef indices_device =\n        ft_.CopyToWorker0(indices_nd, \"logit_pos_local\", {max_num_sequence_}, /*local_only=*/true);\n    ft_.gather_probs_func_(input, indices_device, dst_view);\n    return dst_view;\n  }\n\n  void ScatterDraftProbs(const Tensor& input, const std::vector<int>& indices, Tensor* dst) final {\n    Tensor indices_nd =\n        logit_pos_arr_.CreateView({static_cast<int64_t>(indices.size())}, DataType::Int(32));\n    indices_nd.CopyFromBytes(indices.data(), indices.size() * sizeof(int));\n    TVM_FFI_ICHECK_NE(max_num_sequence_, -1);\n    ObjectRef indices_device =\n        ft_.CopyToWorker0(indices_nd, \"logit_pos_local\", {max_num_sequence_}, /*local_only=*/true);\n    ft_.scatter_probs_func_(input, indices_device, *dst);\n  }\n\n  Array<Tensor> GetMedusaLogits(const ObjectRef& hidden_states) {\n    ObjectRef result = 
ft_.get_logits_func_(hidden_states).cast<ObjectRef>();\n    Array<Tensor> logits{nullptr};\n    if (ft_.use_disco) {\n      logits = Downcast<DRef>(result)->DebugGetFromRemote(0).cast<Array<Tensor>>();\n    } else {\n      logits = Downcast<Array<Tensor>>(result);\n    }\n    return logits;\n  }\n\n  /************** Debug/Profile **************/\n\n  void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) final {\n    ft_.DebugCallFuncOnAllAllWorker(func_name, func_args);\n  }\n\n private:\n  /*! \\brief Load model configuration from JSON. */\n  void LoadModelConfigJSON(const tvm::ffi::json::Object& config) {\n    this->sliding_window_size_ =\n        json::LookupOrDefault<int64_t>(config, \"sliding_window_size\", this->sliding_window_size_);\n    TVM_FFI_ICHECK(sliding_window_size_ == -1 || sliding_window_size_ > 0)\n        << \"Sliding window should be either -1 (which means disabled) or positive\";\n    this->attention_sink_size_ =\n        json::LookupOrDefault<int64_t>(config, \"attention_sink_size\", this->attention_sink_size_);\n    this->attention_sink_size_ = std::max(this->attention_sink_size_, 0);\n    this->vocab_size_ = json::Lookup<int64_t>(config, \"vocab_size\");\n    this->model_type_ = json::Lookup<std::string>(config, \"model_type\");\n  }\n\n  //----------------------------\n  // Model configurations\n  //----------------------------\n  std::string model_;\n  int sliding_window_size_ = -1;\n  int attention_sink_size_ = 0;\n  int num_shards_ = -1;\n  int num_stages_ = -1;\n  int max_num_sequence_ = -1;\n  int prefill_chunk_size_ = -1;\n  int hidden_size_ = -1;\n  DLDataType hidden_states_dtype_;\n  int vocab_size_ = -1;\n  int image_embed_size_ = -1;\n  int seqlen_padding_factor_ = 1;\n  std::string model_type_;\n  //----------------------------\n  // TVM related states\n  //----------------------------\n  // Packed function table\n  FunctionTable ft_;\n  // Paged KV cache.\n  // - We use `kv_cache_` for general KV 
cache operations.\n  // When tensor parallelism is enabled, `kv_cache_` is a DRef object.\n  // - For efficient KV cache raw info query, we use `local_kv_cache_`\n  // as a local **reference** of `kv_cache_`. It is a pure mirror of `kv_cache_`\n  // except that it is always a local object.\n  ObjectRef kv_cache_{nullptr};\n  ObjectRef local_kv_cache_{nullptr};\n  // Runtime device\n  Device device_;\n  // Model parameters\n  ObjectRef params_;\n  // Shared Tensor\n  memory::Storage token_ids_storage_{nullptr};\n  Tensor logit_pos_arr_{nullptr};\n  ObjectRef disco_logits_arr_{nullptr};\n  // A boolean indicating if tracing is enabled.\n  bool trace_enabled_;\n  // An enum indicating whether it's RNN-based.\n  KVStateKind kind;\n  // A set of sequence IDs that have been prefilled.\n  std::unordered_set<int64_t> prefilled_seq_ids_;\n};\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef().def(\n      \"mlc.copy_embedding_to_offset\", [](Tensor embedding, Tensor dst, int offset) {\n        // embedding: (m, hidden_size)\n        // dst: (prefill_chunk_size, hidden_size)\n        TVM_FFI_ICHECK_EQ(embedding->ndim, 2);\n        TVM_FFI_ICHECK_EQ(dst->ndim, 2);\n        TVM_FFI_ICHECK_LE(embedding->shape[0] + offset, dst->shape[0]);\n        TVM_FFI_ICHECK_EQ(embedding->shape[1], dst->shape[1]);\n        const DLTensor& copy_src = *(embedding.operator->());\n        const DLTensor* p_copy_dst = dst.operator->();\n        DLTensor copy_dst = *p_copy_dst;\n        copy_dst.shape = embedding->shape;\n        copy_dst.byte_offset = offset * embedding->shape[1] *\n                               ((embedding->dtype.bits * embedding->dtype.lanes + 7) / 8);\n        Tensor::CopyFromTo(&copy_src, &copy_dst);\n      });\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/model.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/model.h\n * \\brief The header for the runtime module of LLM functions (prefill/decode/etc.)\n */\n\n#ifndef MLC_LLM_SERVE_MODEL_H_\n#define MLC_LLM_SERVE_MODEL_H_\n\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/tensor.h>\n\n#include \"../base.h\"\n#include \"../support/result.h\"\n#include \"config.h\"\n#include \"draft_token_workspace_manager.h\"\n#include \"event_trace_recorder.h\"\n#include \"function_table.h\"\n#include \"logit_processor.h\"\n#include \"sampler/sampler.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\n\n// Declare the sampler class for `Model::CreateSampler`.\nclass Sampler;\n\n/*!\n * \\brief The workspace tensors that may be shared across different\n * calls to Model. For example, the prefill action uses the `embeddings`\n * workspace for the concatenated embeddings of different sequences.\n * The workspace tensors are created by Model but owned by the engine.\n */\nstruct ModelWorkspace {\n  /*!\n   * \\brief The embedding tensor. It can be either a Tensor when tensor\n   * model parallelism is not enabled, or a DRef when using tensor model parallelism.\n   */\n  ObjectRef embeddings{nullptr};\n  /*!\n   * \\brief The hidden_states tensor for the current batch. 
It can be either a Tensor when tensor\n   * model parallelism is not enabled, or a DRef when using tensor model parallelism.\n   */\n  ObjectRef hidden_states{nullptr};\n\n  /*!\n   * \\brief The draft token probabilities tensor for the current batch.\n   */\n  Tensor draft_probs{nullptr};\n\n  /*!\n   * \\brief The hidden_states tensor storing the hidden_states of draft tokens of all requests.\n   */\n  ObjectRef draft_hidden_states_storage{nullptr};\n\n  /*!\n   * \\brief The draft token probabilities tensor storing the probabilities of draft tokens of all\n   * requests.\n   */\n  Tensor draft_probs_storage{nullptr};\n};\n\n/*!\n * \\brief The model module for LLM functions.\n * It runs an LLM, and has an internal KV cache that maintains\n * the history KV values of all processed tokens.\n *\n * It contains the following functions:\n *\n * Model related:\n * - \"token_embed\": take token ids as input and return the embeddings,\n * - \"batch_prefill\": take embedding of a single sequence\n * as input, forward the embedding through LLM and return the logits,\n * - \"decode\": take the embeddings of the last-committed token of an\n * entire batch as input, forward through LLM and return the logits\n * for all sequences in the batch,\n * - \"softmax_with_temperature\": take logits and temperatures, return\n * probabilities.\n *\n * KV cache related:\n * - \"create_kv_cache\": create the KV cache for this module,\n * - \"add_new_sequence\": add (declare) a new sequence in the KV cache,\n * - \"remove_sequence\": remove a sequence from KV cache.\n *\n * ... 
and some other auxiliary functions.\n */\nclass ModelObj : public Object {\n public:\n  /*********************** Model Computation  ***********************/\n\n  /*!\n   * \\brief Compute embeddings for the input token ids.\n   * When the input destination pointer is defined, it writes the\n   * embeddings in place into the destination array at the given offset.\n   * Otherwise, the embeddings are returned directly.\n   * \\param batch_token_ids The token ids to compute embedding for.\n   * \\param dst The destination array of the embedding lookup.\n   * \\param offset The token offset where the computed embeddings will be written\n   * into the destination array.\n   * \\return The updated destination embedding array or the computed embeddings.\n   * \\note When `dst` is undefined, we require `offset` to be 0.\n   */\n  virtual ObjectRef TokenEmbed(IntTuple batch_token_ids, ObjectRef* dst = nullptr,\n                               int offset = 0) = 0;\n\n  /*!\n   * \\brief Compute embeddings for the input image.\n   * \\param image The image to compute embedding for.\n   * \\return The computed embeddings.\n   */\n  virtual ObjectRef ImageEmbed(const Tensor& image, ObjectRef* dst = nullptr, int offset = 0) = 0;\n\n  /*!\n   * \\brief Fuse the embeddings and hidden_states.\n   * \\param embeddings The embedding of the input to be prefilled.\n   * \\param previous_hidden_states The hidden_states from the previous base model.\n   * \\param batch_size Batch size.\n   * \\param seq_len Sequence length.\n   * \\return The fused hidden_states.\n   */\n  virtual ObjectRef FuseEmbedHidden(const ObjectRef& embeddings,\n                                    const ObjectRef& previous_hidden_states, int batch_size,\n                                    int seq_len) = 0;\n\n  /*!\n   * \\brief Return whether the model has an lm_head so that we can get logits.\n   */\n  virtual bool CanGetLogits() = 0;\n\n  /*!\n   * \\brief Compute logits for last hidden_states.\n   * \\param 
last_hidden_states The last hidden_states to compute logits for.\n   * \\return The computed logits.\n   */\n  virtual Tensor GetLogits(const ObjectRef& last_hidden_states) = 0;\n\n  virtual Array<Tensor> GetMultiStepLogits(const ObjectRef& last_hidden_states) = 0;\n\n  /*!\n   * \\brief Batch prefill function. Embedding in, logits out.\n   * The embedding order of sequences in `embeddings` follows\n   * the order of `seq_ids`.\n   * \\param embeddings The embedding of the input to be prefilled.\n   * \\param seq_ids The ids of the sequences in the KV cache.\n   * \\param lengths The length of each sequence to prefill.\n   * \\return The logits for the next token.\n   */\n  virtual Tensor BatchPrefill(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,\n                              const std::vector<int>& lengths) = 0;\n\n  /*!\n   * \\brief Batch prefill function. Input hidden_states are computed from\n   * input embeddings and previous hidden_states, output last hidden_states.\n   * \\param hidden_states The hidden_states of the input to be prefilled.\n   * \\param seq_ids The ids of the sequences in the KV cache.\n   * \\param lengths The length of each sequence to prefill.\n   * \\return The hidden_states for the next token.\n   */\n  virtual ObjectRef BatchPrefillToLastHidden(const ObjectRef& hidden_states,\n                                             const std::vector<int64_t>& seq_ids,\n                                             const std::vector<int>& lengths) = 0;\n\n  /*!\n   * \\brief Batch decode function. 
Embedding in, logits out.\n   * The embedding order of sequences in `embeddings` follows\n   * the order of `seq_ids`.\n   * \\param embeddings The embedding of the last generated token in the entire batch.\n   * \\param seq_ids The ids of the sequences in the KV cache.\n   * \\return The logits for the next token for each sequence in the batch.\n   */\n  virtual Tensor BatchDecode(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids) = 0;\n\n  virtual Tensor BatchTreeDecode(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,\n                                 const std::vector<int>& lengths,\n                                 const std::vector<int64_t>& token_tree_parent_ptr) = 0;\n\n  /*!\n   * \\brief Batch decode function. Input hidden_states are computed from\n   * input embeddings and previous hidden_states, output last hidden_states.\n   * \\param hidden_states The hidden_states of the last generated token in the entire batch.\n   * \\param seq_ids The ids of the sequences in the KV cache.\n   * \\return The hidden_states for the next token for each sequence in the batch.\n   */\n  virtual ObjectRef BatchDecodeToLastHidden(const ObjectRef& hidden_states,\n                                            const std::vector<int64_t>& seq_ids) = 0;\n\n  /*!\n   * \\brief Batch verify function. Embedding in, logits out.\n   * \\param embeddings The embedding of the input to be verified.\n   * \\param seq_ids The ids of the sequences in the KV cache.\n   * \\param lengths The length of each sequence to verify.\n   * \\param token_tree_parent_ptr The parent pointers of the token tree.\n   * Its size is the sum of \"lengths\". It contains a batch of independent trees,\n   * one for each sequence. 
Parent being \"-1\" means the node is a root.\n   * \\return The logits for the draft token for each sequence in the batch.\n   * \\note The function runs for **every** sequence in the batch.\n   * That is to say, it does not accept \"running a verify step for a subset\n   * of the full batch\".\n   */\n  virtual Tensor BatchVerify(const ObjectRef& embeddings, const std::vector<int64_t>& seq_ids,\n                             const std::vector<int>& lengths,\n                             const std::vector<int64_t>& token_tree_parent_ptr) = 0;\n\n  /*!\n   * \\brief Batch verify function. Input hidden_states are computed from\n   * input embeddings and previous hidden_states, output last hidden_states.\n   * \\param hidden_states The hidden_states of the input to be verified.\n   * \\param seq_ids The ids of the sequences in the KV cache.\n   * \\param lengths The length of each sequence to verify.\n   * \\param token_tree_parent_ptr The parent pointers of the token tree.\n   * Its size is the sum of \"lengths\". It contains a batch of independent trees,\n   * one for each sequence. 
Parent being \"-1\" means the node is a root.\n   * \\return The hidden_states for the draft token for each sequence in the batch.\n   * \\note The function runs for **every** sequence in the batch.\n   * That is to say, it does not accept \"running a verify step for a subset\n   * of the full batch\".\n   */\n  virtual ObjectRef BatchVerifyToLastHidden(const ObjectRef& hidden_states,\n                                            const std::vector<int64_t>& seq_ids,\n                                            const std::vector<int>& lengths,\n                                            const std::vector<int64_t>& token_tree_parent_ptr) = 0;\n\n  /*********************** KV Cache Management  ***********************/\n\n  /*!\n   * \\brief Create the KV cache inside the model with regard to the input config.\n   * \\param page_size The number of consecutive tokens handled in each page in paged KV cache.\n   * \\param max_num_sequence The maximum number of sequences that are allowed to be\n   * processed by the KV cache at any time.\n   * \\param max_total_sequence_length The maximum total number of tokens whose KV data\n   * are allowed to exist in the KV cache at any time.\n   * \\param prefill_chunk_size The maximum total sequence length in a single prefill.\n   * \\param max_history_size The maximum history size for RNN state to roll back.\n   * The KV cache does not need this.\n   */\n  virtual void CreateKVCache(int page_size, int max_num_sequence, int64_t max_total_sequence_length,\n                             int64_t prefill_chunk_size, int max_history_size) = 0;\n\n  /*! \\brief Add a new sequence with the given sequence id to the KV cache. */\n  virtual void AddNewSequence(int64_t seq_id) = 0;\n\n  /*! \\brief Fork a sequence from a given parent sequence. */\n  virtual void ForkSequence(int64_t parent_seq_id, int64_t child_seq_id, int64_t fork_pos = -1) = 0;\n\n  /*! \\brief Remove the given sequence from the KV cache in the model. 
*/\n  virtual void RemoveSequence(int64_t seq_id) = 0;\n\n  /*! \\brief Pop the last `num_tokens` tokens of the given sequence from the KV cache. */\n  virtual void PopNFromKVCache(int64_t seq_id, int num_tokens) = 0;\n\n  /*!\n   * \\brief Commit the accepted token tree nodes to KV cache.\n   * The unaccepted token tree nodes will be removed from KV cache.\n   * This is usually used in the verification stage of speculative decoding.\n   */\n  virtual void CommitAcceptedTokenTreeNodesToKVCache(\n      const std::vector<int64_t>& seq_ids, const std::vector<int64_t>& accepted_leaf_indices) = 0;\n\n  /*!\n   * \\brief Enable sliding window for the given sequence.\n   * It is a no-op if the model does not support sliding window.\n   * \\note Since this operation is tied to the underlying KV cache,\n   * we expose it to Engine through the Model interface.\n   * This may be improved by decoupling KV cache and Model in the future.\n   */\n  virtual void EnableSlidingWindowForSeq(int64_t seq_id) = 0;\n\n  /*! \\brief Prepare to receive disaggregated KV data for the specified sequence and length. */\n  virtual IntTuple DisaggPrepareKVRecv(int64_t seq_id, int length) = 0;\n\n  /*! \\brief Prepare to send disaggregated KV data for the specified sequence, starting at `begin_pos`. */\n  virtual void DisaggMarkKVSend(int64_t seq_id, int begin_pos,\n                                IntTuple compressed_kv_append_metadata, int dst_group_offset) = 0;\n\n  /************** Raw Info Query **************/\n\n  /*! \\brief Return the metadata of the model. */\n  virtual ModelMetadata GetMetadata() const = 0;\n\n  /*! \\brief Get the number of available pages in KV cache. */\n  virtual int GetNumAvailablePages() const = 0;\n\n  /*! \\brief Get the current total sequence length in the KV cache. */\n  virtual int GetCurrentTotalSequenceLength() const = 0;\n\n  /*********************** Utilities  ***********************/\n\n  /*! 
\\brief Load the model's weight parameters, which are not loaded at construction time. */\n  virtual void LoadParams() = 0;\n\n  /*!\n   * \\brief Set the maximum number of sequences to be processed for the model,\n   * which is not initialized at construction time.\n   */\n  virtual void SetMaxNumSequence(int max_num_sequence) = 0;\n\n  /*!\n   * \\brief Set the prefill chunk size for the model,\n   * which is not initialized at construction time.\n   */\n  virtual void SetPrefillChunkSize(int prefill_chunk_size) = 0;\n\n  /*! \\brief Create a logit processor from this model. */\n  virtual LogitProcessor CreateLogitProcessor(int max_num_token,\n                                              Optional<EventTraceRecorder> trace_recorder) = 0;\n\n  /*! \\brief Create a sampler from this model. */\n  virtual Sampler CreateSampler(int max_num_sample, int num_models,\n                                Optional<EventTraceRecorder> trace_recorder) = 0;\n\n  /*!\n   * \\brief Estimate the number of CPU units required to drive the model\n   * execution during TP.\n   * \\note This normally equals the number of TP shards (or 0 if\n   * the model does not use TP) and can be used to hint the runtime to\n   * avoid overusing cores elsewhere.\n   */\n  virtual int EstimateHostCPURequirement() const = 0;\n\n  /*! \\brief Get the sliding window size of the model. \"-1\" means sliding window is not enabled. */\n  virtual int GetSlidingWindowSize() const = 0;\n\n  /*! \\brief Get the attention sink size of the model. */\n  virtual int GetAttentionSinkSize() const = 0;\n\n  /*! \\brief Allocate an embedding tensor with the prefill chunk size. */\n  virtual ObjectRef AllocEmbeddingTensor() = 0;\n\n  /*! \\brief Allocate a hidden_states tensor with the prefill chunk size. */\n  virtual ObjectRef AllocHiddenStatesTensor() = 0;\n\n  /*! \\brief Reset the model KV cache and other metrics. */\n  virtual void Reset() = 0;\n\n  /*********************** Utilities for speculative decoding. 
***********************/\n\n  virtual DraftTokenWorkspaceManager CreateDraftTokenWorkspaceManager(int max_num_token) = 0;\n\n  /*! \\brief Gather the hidden_states of the given indices and in-place update the dst tensor. */\n  virtual ObjectRef GatherHiddenStates(const ObjectRef& input, const std::vector<int>& indices,\n                                       ObjectRef* dst) = 0;\n\n  /*! \\brief Scatter the hidden_states of the given indices to the dst tensor. */\n  virtual void ScatterHiddenStates(const ObjectRef& input, const std::vector<int>& indices,\n                                   ObjectRef* dst) = 0;\n\n  /*! \\brief Gather the draft token probabilities of the given indices and in-place update the dst\n   * tensor. */\n  virtual Tensor GatherDraftProbs(const Tensor& input, const std::vector<int>& indices,\n                                  Tensor* dst) = 0;\n\n  /*! \\brief Scatter the draft token probabilities of the given indices to the dst tensor. */\n  virtual void ScatterDraftProbs(const Tensor& input, const std::vector<int>& indices,\n                                 Tensor* dst) = 0;\n\n  /************** Debug/Profile **************/\n\n  /*! \\brief Call the given global function on all workers. Only for debug purpose. 
*/\n  virtual void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<ModelObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.Model\", ModelObj, Object);\n};\n\nclass Model : public ObjectRef {\n public:\n  /*!\n   * \\brief Create the runtime module for LLM functions.\n   * \\param reload_lib_path The model library path.\n   * \\param model_path The path to the model weight parameters.\n   * \\param model_config The model config json object.\n   * \\param device The device to run the model on.\n   * \\param session The session to run the model on.\n   * \\param num_shards The number of tensor parallel shards of the model.\n   * \\param num_stages The number of pipeline parallel stages of the model.\n   * \\param trace_enabled A boolean indicating whether tracing is enabled.\n   * \\return The created runtime module.\n   */\n  static Model Create(String reload_lib_path, String model_path,\n                      const tvm::ffi::json::Object& model_config, DLDevice device,\n                      const Optional<Session>& session, int num_shards, int num_stages,\n                      bool trace_enabled);\n\n  /*!\n   * Load the model config from the given model path.\n   * \\param model_path The path to the model weight parameters.\n   * \\return The model config json object.\n   */\n  static Result<tvm::ffi::json::Object> LoadModelConfig(const String& model_path);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Model, ObjectRef, ModelObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_MODEL_H_\n"
  },
  {
    "path": "cpp/serve/prefix_cache.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/prefix_cache.cc\n */\n#include \"prefix_cache.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/runtime/nvtx.h>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\nTVM_FFI_STATIC_INIT_BLOCK() { PrefixCacheObj::RegisterReflection(); }\n\n/*!\n * \\brief The implementation of prefix cache.\n */\nclass PrefixCacheImpl : public PrefixCacheObj {\n public:\n  /*!\n   * \\brief Constructor of the prefix cache.\n   * \\param max_num_recycling_seqs The maximum number of recycling sequences kept in the prefix cache.\n   * \\param remove_callback The optional callback function to call when removing a sequence.\n   */\n  explicit PrefixCacheImpl(size_t max_num_recycling_seqs, PrefixCacheRemoveCallback remove_callback)\n      : radix_tree_(PagedRadixTree::Create()),\n        max_num_recycling_seqs_(max_num_recycling_seqs),\n        remove_callback_(std::move(remove_callback)) {\n    recycling_seq_lrus_.clear();\n    reversed_recycling_seq_lrus_.clear();\n    seq_states_.clear();\n    seq_sliding_window_infos_.clear();\n    lru_counter_ = 0;\n  }\n\n  /*!\n   * \\brief Insert a new tokenized sequence into Prefix Cache.\n   * \\param seq_id The sequence ID.\n   * \\param tokens The tokens of tokenized sequence.\n   * \\param sliding_window_size The sliding window size for the sequence, or -1 if sliding window\n   * is disabled.\n   * \\param attention_sink_size The attention sink size for the sequence, 0 by default.\n   * \\return The matched result.\n   */\n  PrefixCacheMatchedResult InsertSequence(int64_t seq_id, std::vector<int32_t> tokens,\n                                          int sliding_window_size, int attention_sink_size) final {\n    TVM_FFI_ICHECK_NE(sliding_window_size, 0);\n    TVM_FFI_ICHECK_GE(attention_sink_size, 0);\n    TVM_FFI_ICHECK(seq_states_.find(seq_id) == seq_states_.end());\n    TVM_FFI_ICHECK(seq_sliding_window_infos_.find(seq_id) == 
seq_sliding_window_infos_.end());\n    TVM_FFI_ICHECK(!tokens.empty());\n    CommitSequenceExtention();\n    tokens.pop_back();\n    auto [matched_offset, matched_seqs] = radix_tree_->MatchPrefix(tokens);\n    std::pair<int, size_t> sliding_window_info{sliding_window_size, attention_sink_size};\n    // No prefix matched: add the new sequence directly.\n    if (!matched_offset) {\n      radix_tree_->AddSequence(seq_id);\n      seq_states_.emplace(seq_id, SequenceState::kActive);\n      seq_sliding_window_infos_.emplace(seq_id, sliding_window_info);\n      return PrefixCacheMatchedResult{0, -1, -1, 0};\n    }\n\n    TVM_FFI_ICHECK(!matched_seqs.empty());\n\n    // The logic for reusing recycling sequences differs depending on whether sliding window\n    // is enabled.\n    if (sliding_window_size != -1) {\n      // With sliding window enabled, a recycling sequence can be reused only on an exact\n      // match, since rolling back is not allowed under a sliding window.\n      for (int64_t matched_seq_id : matched_seqs) {\n        if (seq_states_.at(matched_seq_id) == SequenceState::kRecycling &&\n            seq_sliding_window_infos_.at(matched_seq_id) == sliding_window_info) {\n          size_t matched_seq_length = radix_tree_->GetSequenceLength(matched_seq_id);\n          if (matched_seq_length == matched_offset) {\n            ReuseRecyclingSequence(matched_seq_id);\n            return PrefixCacheMatchedResult{matched_offset, -1, matched_seq_id, 0};\n          }\n        }\n      }\n    } else {\n      // Without sliding window, we greedily reuse the shortest matched recycling sequence\n      // (also without sliding window), so that the number of trailing tokens lost or rolled\n      // back is minimized.\n      size_t shortest_recycling_seq_length = 0;\n      int64_t shortest_recycling_seq_id = -1;\n\n      for (int64_t matched_seq_id : matched_seqs) {\n        if (seq_states_.at(matched_seq_id) == SequenceState::kRecycling &&\n            
seq_sliding_window_infos_.at(matched_seq_id) == sliding_window_info) {\n          size_t matched_seq_length = radix_tree_->GetSequenceLength(matched_seq_id);\n          if (shortest_recycling_seq_id == -1 ||\n              matched_seq_length < shortest_recycling_seq_length) {\n            shortest_recycling_seq_id = matched_seq_id;\n            shortest_recycling_seq_length = matched_seq_length;\n          }\n        }\n      }\n      // Reuse only if the matched prefix covers more than 90% of the recycling sequence, so that\n      // fewer than 10% of its tokens are rolled back.\n      if (shortest_recycling_seq_id != -1 && matched_offset > shortest_recycling_seq_length * 0.9) {\n        ReuseRecyclingSequence(shortest_recycling_seq_id);\n        if (shortest_recycling_seq_length > matched_offset) {\n          // The recycling sequence is longer than the matched prefix; roll back the redundant\n          // trailing tokens to match the new sequence.\n          radix_tree_->RollBackSequence(shortest_recycling_seq_id,\n                                        shortest_recycling_seq_length - matched_offset);\n        }\n        return PrefixCacheMatchedResult{matched_offset, -1, shortest_recycling_seq_id,\n                                        shortest_recycling_seq_length - matched_offset};\n      }\n      // No recycling sequence reused; fall back to forking a matched sequence. 
Currently, we only\n      // fork from sequences without sliding window due to the current paged KVCache implementation.\n      size_t longest_forking_offset = 0;\n      int64_t longest_forking_seq_id = -1;\n      for (int64_t matched_seq_id : matched_seqs) {\n        auto [matched_seq_sliding_window_size, matched_seq_attention_sink_size] =\n            seq_sliding_window_infos_.at(matched_seq_id);\n        if (matched_seq_sliding_window_size != -1) {\n          continue;\n        }\n        // The matched sequence has sliding window disabled, so it can be forked at any offset\n        // within the matched prefix.\n        if (matched_offset > longest_forking_offset) {\n          longest_forking_offset = matched_offset;\n          longest_forking_seq_id = matched_seq_id;\n        }\n      }\n      if (longest_forking_offset > 0) {\n        radix_tree_->ForkSequence(seq_id, longest_forking_seq_id, longest_forking_offset);\n        seq_states_.emplace(seq_id, SequenceState::kActive);\n        seq_sliding_window_infos_.emplace(seq_id, sliding_window_info);\n        return PrefixCacheMatchedResult{longest_forking_offset, longest_forking_seq_id, -1, 0};\n      }\n    }\n    // No matched sequence can be forked; fall back to adding a new sequence.\n    radix_tree_->AddSequence(seq_id);\n    seq_states_.emplace(seq_id, SequenceState::kActive);\n    seq_sliding_window_infos_.emplace(seq_id, sliding_window_info);\n    return PrefixCacheMatchedResult{0, -1, -1, 0};\n  }\n\n  /*!\n   * \\brief Extend a sequence with a new tokenized suffix. The extension is cached and lazily\n   * committed by CommitSequenceExtention.\n   * \\param seq_id The sequence to be extended.\n   * \\param tokens The tokens of tokenized sequence suffix to extend.\n   * \\throw Error if the given sequence id is not valid or active.\n   */\n  void ExtendSequence(int64_t seq_id, const std::vector<int32_t>& tokens) final {\n    uncommitted_extended_token_ids_.emplace_back(seq_id, tokens);\n  }\n\n  /*! \\brief Commit all cached sequence extensions added by ExtendSequence. */\n  void CommitSequenceExtention() final {\n    if (uncommitted_extended_token_ids_.empty()) {\n 
     return;\n    }\n    NVTXScopedRange nvtx_scope(\"PrefixCache commit sequence extension\");\n    for (const auto& [seq_id, uncommitted_token_ids] : uncommitted_extended_token_ids_) {\n      if (!HasSequence(seq_id)) {\n        // The sequence has already been removed, so there is nothing to extend.\n        continue;\n      }\n      auto it = seq_states_.find(seq_id);\n      TVM_FFI_ICHECK(it == seq_states_.end() || it->second == SequenceState::kActive);\n      radix_tree_->ExtendSequence(seq_id, uncommitted_token_ids);\n    }\n    uncommitted_extended_token_ids_.clear();\n  }\n\n  /*!\n   * \\brief Roll back a sequence by a number of tokens.\n   * \\param seq_id The ID of the sequence.\n   * \\param num_tokens The number of tokens to be rolled back.\n   * \\throw Error if the given sequence id is not valid or active.\n   */\n  void RollBackSequence(int64_t seq_id, size_t num_tokens) final {\n    CommitSequenceExtention();\n    TVM_FFI_ICHECK(seq_states_.at(seq_id) == SequenceState::kActive);\n    radix_tree_->RollBackSequence(seq_id, num_tokens);\n  }\n\n  /*!\n   * \\brief Recycle a sequence. The recycled sequence will not be removed immediately, as long as\n   * memory is sufficient and the number of sequences in prefix cache is below the maximum number of\n   * sequences. 
It may be reused by a future request.\n   * \\param seq_id The sequence to be recycled.\n   * \\param lazy Whether the sequence should be removed lazily (true) or immediately (false).\n   * \\throw Error if the given sequence id is not valid.\n   */\n  void RecycleSequence(int64_t seq_id, bool lazy = true) final {\n    CommitSequenceExtention();\n    TVM_FFI_ICHECK(seq_states_.at(seq_id) == SequenceState::kActive);\n    TVM_FFI_ICHECK(recycling_seq_lrus_.find(seq_id) == recycling_seq_lrus_.end());\n    if (lazy && max_num_recycling_seqs_ != 0) {\n      // Remove the sequence lazily by keeping it in the cache as a recycling sequence.\n      if (recycling_seq_lrus_.size() == max_num_recycling_seqs_) {\n        // The prefix cache has reached the maximum number of recycling sequences, so evict one\n        // recycling sequence first.\n        TVM_FFI_ICHECK(TryFreeMemory());\n        TVM_FFI_ICHECK_EQ(recycling_seq_lrus_.size(), max_num_recycling_seqs_ - 1);\n      }\n      seq_states_.at(seq_id) = SequenceState::kRecycling;\n      ++lru_counter_;\n      recycling_seq_lrus_.emplace(seq_id, lru_counter_);\n      reversed_recycling_seq_lrus_.emplace(lru_counter_, seq_id);\n    } else {\n      // Remove the sequence immediately.\n      radix_tree_->RemoveSequence(seq_id);\n      if (remove_callback_ != nullptr) {\n        remove_callback_(seq_id);\n      }\n      TVM_FFI_ICHECK(seq_states_.erase(seq_id));\n      TVM_FFI_ICHECK(seq_sliding_window_infos_.erase(seq_id));\n    }\n  }\n\n  /*!\n   * \\brief Try to remove a recycling sequence to free up memory. It removes the oldest recycling\n   * sequence, i.e., the one with the smallest LRU time stamp.\n   * \\return True if a sequence was removed (i.e., memory was freed successfully), false otherwise.\n   */\n  bool TryFreeMemory() final {\n    NVTXScopedRange nvtx_scope(\"PrefixCache TryFreeMemory\");\n    if (reversed_recycling_seq_lrus_.empty()) {\n      // There is no recycling sequence. 
No memory can be freed.\n      return false;\n    }\n    // reversed_recycling_seq_lrus_ is an unordered map, so begin() is not guaranteed to be the\n    // oldest entry. Scan for the smallest LRU time stamp to evict the least recently used one.\n    auto oldest_it = reversed_recycling_seq_lrus_.begin();\n    for (auto it = oldest_it; it != reversed_recycling_seq_lrus_.end(); ++it) {\n      if (it->first < oldest_it->first) oldest_it = it;\n    }\n    auto [lru, seq_id] = *oldest_it;\n    TVM_FFI_ICHECK(seq_states_.at(seq_id) == SequenceState::kRecycling);\n    TVM_FFI_ICHECK_EQ(recycling_seq_lrus_.at(seq_id), lru);\n    radix_tree_->RemoveSequence(seq_id);\n    if (remove_callback_ != nullptr) {\n      remove_callback_(seq_id);\n    }\n    TVM_FFI_ICHECK(seq_states_.erase(seq_id));\n    TVM_FFI_ICHECK(recycling_seq_lrus_.erase(seq_id));\n    TVM_FFI_ICHECK(reversed_recycling_seq_lrus_.erase(lru));\n    TVM_FFI_ICHECK(seq_sliding_window_infos_.erase(seq_id));\n    return true;\n  }\n\n  /*!\n   * \\brief Check if a sequence exists.\n   * \\param seq_id The ID of the sequence.\n   * \\return Whether the sequence exists.\n   * \\throw Error if sequence ID is not valid.\n   */\n  bool HasSequence(int64_t seq_id) final { return radix_tree_->HasSequence(seq_id); }\n\n  /*!\n   * \\brief Reset the prefix cache to its initial state.\n   */\n  void Reset() final {\n    radix_tree_->Reset();\n    recycling_seq_lrus_.clear();\n    reversed_recycling_seq_lrus_.clear();\n    seq_states_.clear();\n    seq_sliding_window_infos_.clear();\n    uncommitted_extended_token_ids_.clear();\n    lru_counter_ = 0;\n  }\n\n  PrefixCacheMode Mode() final { return PrefixCacheMode::kRadix; }\n\n private:\n  /*!\n   * \\brief Turn a recycling sequence back into an active one and drop it from the LRU maps.\n   * \\param seq_id The ID of the recycling sequence to reuse.\n   */\n  void ReuseRecyclingSequence(int64_t seq_id) {\n    TVM_FFI_ICHECK(seq_states_.at(seq_id) == SequenceState::kRecycling);\n    size_t lru = recycling_seq_lrus_.at(seq_id);\n    TVM_FFI_ICHECK_EQ(reversed_recycling_seq_lrus_.at(lru), seq_id);\n    seq_states_.at(seq_id) = SequenceState::kActive;\n    TVM_FFI_ICHECK(recycling_seq_lrus_.erase(seq_id));\n    TVM_FFI_ICHECK(reversed_recycling_seq_lrus_.erase(lru));\n  }\n\n  /*!\n   * \\brief The sequence states.\n   */\n  enum class SequenceState : int {\n    /*!\n     * \\brief The state of an active sequence. In this state, the sequence can only be forked. 
When\n     * recycling a sequence, it transitions to kRecycling.\n     */\n    kActive = 0,\n    /*!\n     * \\brief The state of a recycling sequence. In this state, the sequence can be forked or\n     * reused. It transitions back to kActive only when reused.\n     */\n    kRecycling = 1,\n  };\n  /*!\n   * \\brief The core data structure: a paged radix tree.\n   */\n  PagedRadixTree radix_tree_;\n  /*!\n   * \\brief The map from recycling sequence ID to its LRU time stamp.\n   */\n  std::unordered_map<int64_t, size_t> recycling_seq_lrus_;\n  /*!\n   * \\brief The map from LRU time stamp to recycling sequence ID, used to find the sequence with the\n   * earliest LRU time stamp.\n   */\n  std::unordered_map<size_t, int64_t> reversed_recycling_seq_lrus_;\n  /*!\n   * \\brief The maximum number of recycling sequences in prefix cache. -1 means no limit.\n   */\n  int max_num_recycling_seqs_ = -1;\n  /*!\n   * \\brief The LRU counter.\n   */\n  size_t lru_counter_ = 0;\n  /*!\n   * \\brief The callback function to call when removing a sequence. This can be used to lazily\n   * remove the sequence from KVCache and return the sequence ID to the ID manager.\n   */\n  PrefixCacheRemoveCallback remove_callback_ = nullptr;\n  /*!\n   * \\brief The map from sequence ID to its sequence state.\n   */\n  std::unordered_map<int64_t, SequenceState> seq_states_;\n  /*!\n   * \\brief The map from sequence to its sliding window information. The sliding window information\n   * is a pair of sliding window size and attention sink size. The sliding window size is -1 when\n   * sliding window is disabled, or positive otherwise. 
The attention sink size is\n   * non-negative and only used when the sliding window size is positive.\n   */\n  std::unordered_map<int64_t, std::pair<int, size_t>> seq_sliding_window_infos_;\n  /*!\n   * \\brief The collection of uncommitted extended token ids of sequences.\n   * The \"ExtendSequence\" method only lazily adds token ids to this collection,\n   * and these uncommitted token ids are committed when needed.\n   *\n   * Note: Since the stored tokens are references, CommitSequenceExtention should be called after\n   * each action, to avoid uncaught changes to the uncommitted extended token ids.\n   */\n  std::vector<std::pair<int64_t, const std::vector<int32_t>&>> uncommitted_extended_token_ids_;\n};\n\n/*!\n * \\brief The implementation used when prefix caching is disabled.\n */\nclass NoPrefixCache : public PrefixCacheObj {\n public:\n  /*!\n   * \\brief Insert a new tokenized sequence into Prefix Cache.\n   * \\param seq_id The sequence ID.\n   * \\param tokens The tokens of tokenized sequence.\n   * \\param sliding_window_size The sliding window size for the sequence, or -1 if sliding window is\n   * disabled.\n   * \\param attention_sink_size The attention sink size for the sequence, 0 by default.\n   * \\return The matched result.\n   */\n  PrefixCacheMatchedResult InsertSequence(int64_t seq_id, std::vector<int32_t> tokens,\n                                          int sliding_window_size, int attention_sink_size) final {\n    // Since there is no prefix cache, always treat the sequence as a new one.\n    return PrefixCacheMatchedResult{0, -1, -1, 0};\n  }\n\n  /*!\n   * \\brief Extend a sequence with a new tokenized suffix. This is a no-op since there is no\n   * prefix cache.\n   * \\param seq_id The sequence to be extended.\n   * \\param tokens The tokens of tokenized sequence suffix to extend.\n   */\n  void ExtendSequence(int64_t seq_id, const std::vector<int32_t>& tokens) final {\n    // No-op.\n  }\n\n  void CommitSequenceExtention() final 
{\n    // No-op.\n  }\n\n  /*!\n   * \\brief Roll back a sequence by a number of tokens.\n   * \\param seq_id The ID of the sequence.\n   * \\param num_tokens The number of tokens to be rolled back.\n   * \\throw Error if called, since this method should never be called.\n   */\n  void RollBackSequence(int64_t seq_id, size_t num_tokens) final {\n    // Since there is no prefix cache, this method should never be called.\n    LOG(FATAL) << \"Unreachable code.\";\n  }\n\n  /*!\n   * \\brief Recycle a sequence. The recycled sequence will not be removed immediately, as long as\n   * memory is sufficient and the number of sequences in prefix cache is below the maximum number of\n   * sequences. It may be reused by a future request.\n   * \\param seq_id The sequence to be recycled.\n   * \\param lazy Whether the sequence should be removed lazily (true) or immediately (false).\n   * \\throw Error if called, since this method should never be called.\n   */\n  void RecycleSequence(int64_t seq_id, bool lazy = true) final {\n    // Since there is no prefix cache, this method should never be called.\n    LOG(FATAL) << \"Unreachable code.\";\n  }\n\n  /*!\n   * \\brief Try to remove a recycling sequence to free up memory.\n   * \\return Always false, since no sequence is stored.\n   */\n  bool TryFreeMemory() final {\n    // Since there is no prefix cache, always return false.\n    return false;\n  }\n\n  /*!\n   * \\brief Check if a sequence exists.\n   * \\param seq_id The ID of the sequence.\n   * \\return Always false, since no sequence is stored.\n   */\n  bool HasSequence(int64_t seq_id) final {\n    // Since there is no prefix cache, always return false.\n    return false;\n  }\n\n  /*!\n   * \\brief Reset the prefix cache to its initial state. 
Do nothing and return.\n   */\n  void Reset() final {}\n\n  PrefixCacheMode Mode() final { return PrefixCacheMode::kDisable; }\n};\n\nPrefixCache PrefixCache::CreateRadixPrefixCache(size_t max_num_recycling_seqs,\n                                                PrefixCacheRemoveCallback remove_callback) {\n  ObjectPtr<PrefixCacheImpl> n =\n      tvm::ffi::make_object<PrefixCacheImpl>(max_num_recycling_seqs, std::move(remove_callback));\n  return PrefixCache(std::move(n));\n}\n\nPrefixCache PrefixCache::CreateNoPrefixCache() {\n  ObjectPtr<NoPrefixCache> n = tvm::ffi::make_object<NoPrefixCache>();\n  return PrefixCache(std::move(n));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/prefix_cache.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/prefix_cache.h\n */\n#ifndef MLC_LLM_SERVE_PREFIX_CACHE_H_\n#define MLC_LLM_SERVE_PREFIX_CACHE_H_\n#include <tvm/ffi/container/shape.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/object.h>\n\n#include <functional>\n#include <optional>\n#include <unordered_map>\n#include <unordered_set>\n\n#include \"model.h\"\n#include \"radix_tree.h\"\n#include \"request_state.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The signature of the callback function invoked when a sequence is removed.\n */\nusing PrefixCacheRemoveCallback = std::function<void(int64_t)>;\n\n/*!\n * \\brief The matched result from prefix cache. This result describes how to pre-process the new\n * sequence, to leverage the existing data in KVCache by reusing past sequences or forking from\n * other sequences.\n */\nclass PrefixCacheMatchedResult {\n public:\n  /*!\n   * \\brief The matched and prefilled prefix offset.\n   */\n  size_t prefilled_offset = 0;\n  /*!\n   * \\brief The sequence ID to fork from, or -1 if no sequence is forked.\n   */\n  int64_t forked_seq_id = -1;\n  /*!\n   * \\brief The finished sequence ID to reuse, or -1 if no sequence is reused.\n   */\n  int64_t reused_seq_id = -1;\n  /*!\n   * \\brief The number of trailing tokens to be popped from the reused sequence.\n   */\n  size_t reused_seq_pop_last_tokens = 0;\n};\n\nclass PrefixCacheObj : public Object {\n public:\n  /*!\n   * \\brief Insert a new tokenized sequence into Prefix Cache.\n   * \\param seq_id The sequence ID.\n   * \\param tokens The tokens of tokenized sequence.\n   * \\param sliding_window_size The sliding window size for the sequence, or -1 if sliding window is\n   * disabled.\n   * \\param attention_sink_size The attention sink size for the sequence, 0 by default.\n   * \\return The matched result.\n   */\n  virtual PrefixCacheMatchedResult InsertSequence(int64_t seq_id, std::vector<int32_t> tokens,\n                                                
  int sliding_window_size = -1,\n                                                  int attention_sink_size = 0) = 0;\n\n  /*!\n   * \\brief Extend a sequence with a new tokenized suffix.\n   * This extension might be cached and lazily committed later.\n   * \\param seq_id The sequence to be extended.\n   * \\param tokens The tokens of tokenized sequence suffix to extend.\n   * \\throw Error if the given sequence id is not valid or active.\n   */\n  virtual void ExtendSequence(int64_t seq_id, const std::vector<int32_t>& tokens) = 0;\n\n  /*! \\brief Commit the cached sequence extensions from \"ExtendSequence\". */\n  virtual void CommitSequenceExtention() = 0;\n\n  /*!\n   * \\brief Roll back a sequence by a number of tokens.\n   * \\param seq_id The ID of the sequence.\n   * \\param num_tokens The number of tokens to be rolled back.\n   * \\throw Error if the given sequence id is not valid or active.\n   */\n  virtual void RollBackSequence(int64_t seq_id, size_t num_tokens) = 0;\n\n  /*!\n   * \\brief Recycle a sequence. The recycled sequence will not be removed immediately, as long as\n   * memory is sufficient and the number of sequences in prefix cache is below the maximum number of\n   * sequences. It may be reused by a future request.\n   * \\param seq_id The sequence to be recycled.\n   * \\param lazy Whether the sequence should be removed lazily (true) or immediately (false).\n   * \\throw Error if the given sequence id is not valid.\n   */\n  virtual void RecycleSequence(int64_t seq_id, bool lazy = true) = 0;\n\n  /*!\n   * \\brief Try to remove a recycling sequence to free up memory. It removes the oldest recycling\n   * sequence.\n   * \\return Whether a sequence was removed. 
In other words, it returns true when memory is\n   * freed successfully.\n   */\n  virtual bool TryFreeMemory() = 0;\n\n  /*!\n   * \\brief Check if a sequence exists.\n   * \\param seq_id The ID of the sequence.\n   * \\return Whether the sequence exists.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual bool HasSequence(int64_t seq_id) = 0;\n\n  /*!\n   * \\brief Reset the prefix cache to its initial state.\n   */\n  virtual void Reset() = 0;\n\n  /*! \\brief Return the prefix cache mode. */\n  virtual PrefixCacheMode Mode() = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<PrefixCacheObj>();\n  }\n\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.PrefixCache\", PrefixCacheObj, Object);\n};\n\nclass PrefixCache : public ObjectRef {\n public:\n  /*!\n   * \\brief Create a radix-tree-based prefix cache.\n   * \\param max_recycling_seqs The maximum number of recycling sequences in prefix cache.\n   * \\param remove_callback The optional callback function to call when removing a sequence.\n   */\n  static PrefixCache CreateRadixPrefixCache(size_t max_recycling_seqs,\n                                            PrefixCacheRemoveCallback remove_callback = nullptr);\n  /*!\n   * \\brief Create a prefix cache that stores nothing, i.e., with prefix caching disabled.\n   */\n  static PrefixCache CreateNoPrefixCache();\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PrefixCache, ObjectRef, PrefixCacheObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_PREFIX_CACHE_H_\n"
  },
  {
    "path": "cpp/serve/radix_tree.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/radix_tree.cc\n */\n#include \"radix_tree.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/logging.h>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\nTVM_FFI_STATIC_INIT_BLOCK() { PagedRadixTreeObj::RegisterReflection(); }\n\n/*!\n * \\brief The sequence ID linked list structure in paged radix tree node.\n */\nstruct SequenceIDNode {\n  /*! \\brief The stored sequence ID. */\n  int64_t id = 0;\n  /*! \\brief The pointer to the next sequence ID. */\n  SequenceIDNode* next = nullptr;\n};\n\n/*!\n * \\brief The sequence ID node pool.\n *\n * The sequence ID node pool allocates a new block of sequence ID nodes when the pool runs out of\n * free nodes, and frees all blocks on destruction, to avoid frequent memory allocation.\n */\nclass SequenceIDNodePool {\n public:\n  /*! \\brief The constructor of sequence ID node pool, allocating a new sequence ID node block. 
*/\n  SequenceIDNodePool() {\n    NewNodeBlock_();\n    used_nodes_.clear();\n  }\n\n  /*!\n   * \\brief Get a sequence ID node from pool, and assign the fields.\n   * If there is no available node, it will allocate a new sequence ID node block.\n   * \\param seq_id The assigned sequence ID of allocated sequence ID node.\n   * \\param next The next sequence ID node pointer of allocated sequence ID node.\n   * \\return The allocated sequence ID node.\n   */\n  SequenceIDNode* Allocate(int64_t seq_id, SequenceIDNode* next) {\n    if (free_node_indices_.empty()) {\n      NewNodeBlock_();\n      TVM_FFI_ICHECK(!free_node_indices_.empty());\n    }\n    size_t id = free_node_indices_.back();\n    free_node_indices_.pop_back();\n    SequenceIDNode* node = nodes_[id];\n    used_nodes_[node] = id;\n    node->id = seq_id;\n    node->next = next;\n    return node;\n  }\n\n  /*!\n   * \\brief Return a sequence ID node to the pool.\n   * \\param node The sequence ID node to free.\n   */\n  void Free(SequenceIDNode* node) {\n    TVM_FFI_ICHECK(used_nodes_.find(node) != used_nodes_.end());\n    free_node_indices_.push_back(used_nodes_[node]);\n    used_nodes_.erase(node);\n  }\n\n  /*!\n   * \\brief Reset the sequence ID node pool to its initial state, marking every node as free.\n   */\n  void Reset() {\n    used_nodes_.clear();\n    // Use resize (not reserve) so that every index written below is in bounds.\n    free_node_indices_.resize(nodes_.size());\n    for (size_t i = 0; i < nodes_.size(); ++i) {\n      nodes_[i]->id = 0;\n      nodes_[i]->next = nullptr;\n      free_node_indices_[i] = i;\n    }\n  }\n\n  /*! \\brief The destructor of sequence ID node pool, freeing every node block. */\n  ~SequenceIDNodePool() {\n    for (SequenceIDNode* node_block : node_blocks_) {\n      delete[] node_block;\n    }\n  }\n\n private:\n  /*! \\brief The size of each node pool block. */\n  static constexpr size_t kNodeBlockSize_ = 64;\n  /*! \\brief The raw sequence ID node block pool, each element is a sequence ID node array. */\n  std::vector<SequenceIDNode*> node_blocks_;\n  /*! 
\\brief The sequence ID node pool, each element is a sequence ID node pointer. */\n  std::vector<SequenceIDNode*> nodes_;\n  /*! \\brief The indices of free sequence ID nodes in node pool. */\n  std::vector<size_t> free_node_indices_;\n  /*! \\brief The map from each used sequence ID node to its index in node pool. */\n  std::unordered_map<SequenceIDNode*, size_t> used_nodes_;\n\n  /*! \\brief Allocate a new node pool block. */\n  void NewNodeBlock_() {\n    size_t node_id_offset = node_blocks_.size() * kNodeBlockSize_;\n    node_blocks_.push_back(new SequenceIDNode[kNodeBlockSize_]);\n    nodes_.reserve(nodes_.size() + kNodeBlockSize_);\n    free_node_indices_.reserve(free_node_indices_.size() + kNodeBlockSize_);\n    for (size_t i = 0; i < kNodeBlockSize_; ++i) {\n      nodes_.push_back(&node_blocks_.back()[i]);\n      free_node_indices_.push_back(i + node_id_offset);\n    }\n  }\n};\n\n/*!\n * \\brief The paged radix tree node data structure.\n *\n * The paged radix tree node is similar to an ordinary radix tree node, but the length of the prefix\n * stored in each page is bounded, so that every page has the same fixed memory size once allocated.\n * Since the page only consists of pointers, sizes and int32 tokens, its memory layout is\n * effectively an int32 array: the lower offsets hold the pointers and page information, while the\n * higher offsets hold the stored prefix tokens.\n *\n * Since the vocabulary size may be very large, the paged radix tree is represented as a\n * left-child, right-sibling binary tree.\n *\n * Also, since tokens may be popped or pushed at the front or back of a page, the page is designed\n * as a circular buffer to make full use of its capacity.\n *\n * Each page records the sequences that end exactly with the prefix tokens stored in the page. In\n * other words, every sequence ends at a page boundary, i.e., at the end of some page.\n */\nstruct RadixPage {\n  /*! \\brief The parent page. */\n  RadixPage* parent;\n  /*! \\brief The first child page. 
*/\n  RadixPage* first_child;\n  /*! \\brief The sibling page sharing the same parent page. */\n  RadixPage* next_sibling;\n  /*! \\brief The head of sequence ID linked list. */\n  SequenceIDNode* seq_ids;\n  /*! \\brief The maximum number of prefix tokens the page can store. */\n  size_t capacity;\n  /*! \\brief The start offset of stored prefix tokens. The legal range is [0, capacity). */\n  size_t offset;\n  /*! \\brief The length of stored prefix tokens. The legal range is [0, capacity]. */\n  size_t length;\n  /*! \\brief The offset of first prefix token in memory layout. */\n  static constexpr int kDataOffset = (sizeof(RadixPage*) * 3 + sizeof(SequenceIDNode*) +\n                                      sizeof(size_t) * 3 + sizeof(int32_t) - 1) /\n                                     sizeof(int32_t);\n\n  /*!\n   * \\brief Access the i-th stored prefix token, treating the circular buffer as a plain int array.\n   * \\param i The prefix token index.\n   * \\return The value of i-th prefix token.\n   */\n  int32_t& operator[](size_t i) {\n    return reinterpret_cast<int32_t*>(this)[kDataOffset + (i + offset) % capacity];\n  }\n\n  /*!\n   * \\brief Extend the page by pushing back suffix tokens.\n   * \\param suffix The suffix tokens array.\n   * \\param suffix_length The suffix length to extend.\n   * \\throw Error if suffix length is larger than the current vacant space.\n   */\n  void Extend(const int32_t* suffix, size_t suffix_length) {\n    TVM_FFI_ICHECK_LE(suffix_length + length, capacity);\n    for (size_t i = 0; i < suffix_length; ++i) {\n      (*this)[i + length] = suffix[i];\n    }\n    length += suffix_length;\n  }\n\n  /*!\n   * \\brief Add a sequence ID to page.\n   * \\param pool The sequence ID node pool to allocate a new node from.\n   * \\param id The sequence ID to add.\n   */\n  void AddSequence(SequenceIDNodePool* pool, int64_t id) { seq_ids = pool->Allocate(id, seq_ids); }\n\n  /*!\n   * \\brief Pop a sequence ID from page.\n   * \\param pool The sequence ID node pool to free 
popped node.\n   * \\param id The sequence ID to pop.\n   * \\throw Error if no such sequence ID in page.\n   */\n  void PopSequence(SequenceIDNodePool* pool, int64_t id) {\n    TVM_FFI_ICHECK(seq_ids != nullptr);\n    if (seq_ids->id == id) {\n      // The popped sequence ID is the head of the linked list:\n      // advance the head and free the old head node.\n      SequenceIDNode* next = seq_ids->next;\n      pool->Free(seq_ids);\n      seq_ids = next;\n    } else {\n      // The popped sequence ID is not the head of the linked list:\n      // unlink it from its predecessor and free it.\n      SequenceIDNode* last = seq_ids;\n      SequenceIDNode* cur = seq_ids->next;\n      while (cur) {\n        if (cur->id == id) {\n          last->next = cur->next;\n          pool->Free(cur);\n          return;\n        }\n        last = cur;\n        cur = cur->next;\n      }\n      LOG(FATAL) << \"Sequence ID = \" << id << \" not found.\";\n    }\n  }\n\n  /*!\n   * \\brief Get all sequence IDs in page.\n   * \\return The std::vector of sequence IDs in page.\n   */\n  std::vector<int64_t> GetLocalSequence() {\n    std::vector<int64_t> output;\n    for (SequenceIDNode* node = seq_ids; node; node = node->next) {\n      output.push_back(node->id);\n    }\n    return output;\n  }\n\n  /*!\n   * \\brief Get any sequence ID in current page or its descendant pages.\n   * Since every leaf page holds at least one sequence, it only descends into the first child if\n   * the current page holds no sequence ID.\n   * \\return Any sequence ID in current page or its descendant pages.\n   */\n  int64_t FindAnyChildSequence() {\n    if (seq_ids) return seq_ids->id;\n    return first_child->FindAnyChildSequence();\n  }\n\n  /*!\n   * \\brief Get all sequence IDs in current page and its descendant pages, using the Iterate method\n   * with a lambda callback to avoid frequent std::vector memory allocation.\n   * \\return The std::vector of all sequence IDs in current page and its descendant pages.\n   */\n  std::vector<int64_t> 
output = GetLocalSequence();\n    if (first_child) {\n      first_child->Iterate([&output](const RadixPage* page) {\n        for (SequenceIDNode* node = page->seq_ids; node; node = node->next) {\n          output.push_back(node->id);\n        }\n      });\n    }\n    return output;\n  }\n\n  /*!\n   * \\brief Traverse the current page, its subsequent siblings, and all of their descendant pages.\n   * \\param f The callback function to invoke at each radix page visited.\n   */\n  template <class CallbackFunc>\n  void Iterate(CallbackFunc f) {\n    f(this);\n    if (next_sibling) next_sibling->Iterate(f);\n    if (first_child) first_child->Iterate(f);\n  }\n\n  /*!\n   * \\brief Get the previous sibling of current page.\n   * \\return The page whose next_sibling is the current page, or nullptr if the current page has\n   * no parent or is the first_child of its parent page.\n   */\n  RadixPage* GetLastSibling() {\n    if (parent == nullptr) return nullptr;\n    if (parent->first_child == this) return nullptr;\n    for (RadixPage* child = parent->first_child; child; child = child->next_sibling) {\n      if (child->next_sibling == this) return child;\n    }\n    return nullptr;\n  }\n\n  /*!\n   * \\brief Find the child page whose first token is the given token.\n   * \\param first_token The first token to look up.\n   * \\return The child page starting with first_token, or nullptr if there is no such child page.\n   */\n  RadixPage* FindChild(int64_t first_token) {\n    int32_t casted = first_token;\n    // Check every child radix page, since child pages are not stored in any particular order.\n    for (RadixPage* child = first_child; child; child = child->next_sibling) {\n      if ((*child)[0] == casted) return child;\n    }\n    return nullptr;\n  }\n\n  /*! \\brief Insert a new child page. 
*/\n  void InsertChild(RadixPage* child) {\n    child->parent = this;\n    child->next_sibling = first_child;\n    first_child = child;\n  }\n\n  /*!\n   * \\brief Remove a child page.\n   * \\param child The child page to remove.\n   * \\throw Error if the page to be removed is not a child of the current page.\n   */\n  void RemoveChild(RadixPage* child) {\n    TVM_FFI_ICHECK(child->parent == this);\n    if (first_child == child) {\n      first_child = child->next_sibling;\n    } else {\n      child->GetLastSibling()->next_sibling = child->next_sibling;\n    }\n  }\n\n  /*!\n   * \\brief Check whether the current page is mergeable with its child page.\n   * The page is mergeable if and only if\n   * 1. The current page holds no sequence ID, since sequences may only end at page boundaries.\n   * 2. The current page has a child page.\n   * 3. The current page has exactly one child page.\n   * 4. The prefixes of the current page and the child page fit together into one page.\n   * \\return True if the current page is mergeable, false otherwise.\n   */\n  bool Mergeable() {\n    if (seq_ids) return false;\n    if (!first_child) return false;\n    if (first_child->next_sibling) return false;\n    if (length + first_child->length > capacity) return false;\n    return true;\n  }\n\n  /*!\n   * \\brief Match the given prefix against the tokens stored in the page.\n   * \\param prefix The prefix token array.\n   * \\param prefix_length The length of prefix token array.\n   * \\return The length of the matched prefix within page, i.e., the first mismatched\n   * token position. 
The\n   * possible return value is in [0, page->length], where page->length means the whole page\n   * matches the beginning of the given prefix.\n   */\n  size_t MatchPrefix(const int32_t* prefix, size_t prefix_length) {\n    size_t n = std::min(length, prefix_length);\n    for (size_t i = 0; i < n; ++i) {\n      if ((*this)[i] != prefix[i]) return i;\n    }\n    return n;\n  }\n};\n\n/*!\n * \\brief The paged radix tree page pool.\n *\n * The page pool allocates radix tree pages block by block when it runs out of free pages,\n * and frees all blocks on destruction, to avoid frequent memory allocation.\n */\nclass RadixPagePool {\n public:\n  /*! \\brief The constructor of paged radix tree page pool, allocating memory for each page. */\n  RadixPagePool() {\n    NewPageBlock_();\n    used_pages_.clear();\n  }\n\n  /*!\n   * \\brief Get a radix page from the pool.\n   * If there is no available page, it will allocate a new radix page block.\n   * \\return The allocated radix page.\n   */\n  RadixPage* Allocate() {\n    if (free_page_indices_.empty()) {\n      NewPageBlock_();\n      TVM_FFI_ICHECK(!free_page_indices_.empty());\n    }\n    size_t id = free_page_indices_.back();\n    free_page_indices_.pop_back();\n    RadixPage* page = pages_[id];\n    used_pages_[page] = id;\n    page->parent = page->first_child = page->next_sibling = nullptr;\n    page->capacity = kPageCapacity_;\n    page->offset = page->length = 0;\n    page->seq_ids = nullptr;\n    return page;\n  }\n\n  /*!\n   * \\brief Return a radix page to the pool.\n   * \\param page The radix page to free.\n   */\n  void Free(RadixPage* page) {\n    TVM_FFI_ICHECK_EQ(page->seq_ids, nullptr);\n    TVM_FFI_ICHECK(used_pages_.find(page) != used_pages_.end());\n    free_page_indices_.push_back(used_pages_[page]);\n    TVM_FFI_ICHECK(used_pages_.erase(page));\n  }\n\n  /*!\n   * \\brief Get the token capacity of free pages.\n   * \\return The token capacity of free pages.\n   */\n  size_t FreeCapacity() { 
return free_page_indices_.size() 
* kPageCapacity_; }\n\n  /*!\n   * \\brief Reset the page pool to its initial status, marking every page as free.\n   */\n  void Reset() {\n    used_pages_.clear();\n    // Resize (not just reserve) so that every page index is written back into the free list.\n    free_page_indices_.resize(pages_.size());\n    for (size_t i = 0; i < pages_.size(); ++i) {\n      pages_[i]->parent = pages_[i]->first_child = pages_[i]->next_sibling = nullptr;\n      pages_[i]->capacity = kPageCapacity_;\n      pages_[i]->offset = pages_[i]->length = 0;\n      pages_[i]->seq_ids = nullptr;\n      free_page_indices_[i] = i;\n    }\n  }\n\n  /*! \\brief The destructor of paged radix tree page pool, freeing memory for each page. */\n  ~RadixPagePool() {\n    for (int32_t* page_block : page_blocks_) {\n      delete[] page_block;\n    }\n  }\n\n private:\n  /*! \\brief The number of pages in each page pool block. */\n  static constexpr size_t kPageBlockSize_ = 64;\n  /*! \\brief The token capacity of each paged radix tree page. */\n  static constexpr size_t kPageCapacity_ = 64;\n  /*! \\brief The size of each page in int32_t units, including the page header. */\n  static constexpr size_t kPageSize_ = kPageCapacity_ + RadixPage::kDataOffset;\n  /*! \\brief The raw paged radix tree page block pool,\n  each element is a raw paged radix tree page array. */\n  std::vector<int32_t*> page_blocks_;\n  /*! \\brief The paged radix tree page pool,\n  each element is a raw paged radix tree page pointer. */\n  std::vector<RadixPage*> pages_;\n  /*! \\brief The indices of free radix pages in the page pool. */\n  std::vector<size_t> free_page_indices_;\n  /*! \\brief The map from used paged radix tree page to its index in page pool. */\n  std::unordered_map<RadixPage*, size_t> used_pages_;\n\n  /*! \\brief Allocate a new page pool block. 
*/\n  void NewPageBlock_() {\n    size_t page_id_offset = page_blocks_.size() * kPageBlockSize_;\n    page_blocks_.push_back(new int32_t[kPageBlockSize_ * kPageSize_]);\n    pages_.reserve(pages_.size() + kPageBlockSize_);\n    free_page_indices_.reserve(free_page_indices_.size() + kPageBlockSize_);\n    for (size_t i = 0; i < kPageBlockSize_; ++i) {\n      pages_.push_back(reinterpret_cast<RadixPage*>(page_blocks_.back() + i * kPageSize_));\n      free_page_indices_.push_back(i + page_id_offset);\n    }\n  }\n};\n\n// PagedRadixTree\n\n/*!\n * \\brief The paged radix tree data structure.\n */\nclass PagedRadixTreeImpl : public PagedRadixTreeObj {\n public:\n  /*! \\brief The map from each sequence ID to the radix page where the sequence ends. */\n  std::unordered_map<int64_t, RadixPage*> seq2page;\n  /*! \\brief The sequence ID node pool. */\n  SequenceIDNodePool* seq_id_node_pool = nullptr;\n  /*! \\brief The radix page pool. */\n  RadixPagePool* radix_page_pool = nullptr;\n  /*! \\brief The root page of paged radix tree. 
*/\n  RadixPage* root = nullptr;\n\n  explicit PagedRadixTreeImpl() {\n    seq_id_node_pool = new SequenceIDNodePool();\n    radix_page_pool = new RadixPagePool();\n\n    root = reinterpret_cast<RadixPage*>(new int32_t[RadixPage::kDataOffset]);\n    root->parent = root->first_child = root->next_sibling = nullptr;\n    root->offset = root->length = root->capacity = 0;\n    root->seq_ids = nullptr;\n  }\n\n  /*!\n   * \\brief Check if a sequence exists.\n   * \\param seq_id The sequence ID for index.\n   * \\return Whether the sequence exists.\n   */\n  bool HasSequence(int64_t seq_id) { return seq2page.find(seq_id) != seq2page.end(); }\n\n  /*!\n   * \\brief Get all tokens of a sequence.\n   * \\param seq_id The sequence ID for index.\n   * \\return The sequence tokens.\n   * \\throw Error if sequence ID is not valid.\n   */\n  IntTuple GetSequence(int64_t seq_id) {\n    TVM_FFI_ICHECK(seq2page.find(seq_id) != seq2page.end());\n    size_t length = GetSequenceLength(seq_id);\n    std::vector<int64_t> output(length);\n    size_t offset = length;\n    for (RadixPage* page = seq2page[seq_id]; page; page = page->parent) {\n      offset -= page->length;\n      for (size_t i = 0; i < page->length; ++i) {\n        output[offset + i] = (*page)[i];\n      }\n    }\n    return IntTuple(output);\n  }\n\n  /*!\n   * \\brief Get all sequences sharing the longest common prefix with the given prefix tokens.\n   * \\param tokens The prefix tokens for reference.\n   * \\return The pair of the matched prefix length and the IDs of the matched sequences.\n   */\n  std::pair<size_t, std::vector<int64_t>> MatchPrefix(const std::vector<int32_t>& tokens) {\n    const int32_t* prefix = tokens.data();\n    size_t length = tokens.size();\n    auto [page, offset, in_page_offset] = MatchSequence(root, prefix, length);\n    if (!offset) return std::make_pair(0, std::vector<int64_t>());\n    return std::make_pair(offset, 
page->FindAllChildSequence());\n  }\n\n  /*!\n  
 * \\brief Get a sequence's length.\n   * \\param seq_id The sequence ID for index.\n   * \\return The sequence length.\n   * \\throw Error if sequence ID is not valid.\n   */\n  size_t GetSequenceLength(int64_t seq_id) {\n    TVM_FFI_ICHECK(seq2page.find(seq_id) != seq2page.end());\n    size_t length = 0;\n    for (RadixPage* page = seq2page[seq_id]; page; page = page->parent) {\n      length += page->length;\n    }\n    return length;\n  }\n\n  /*!\n   * \\brief Fork a new sequence from a parent sequence at the given position.\n   * \\param seq_id The new sequence ID.\n   * \\param parent_seq_id The parent sequence ID to fork from.\n   * \\param forked_offset The position of parent sequence to fork at.\n   * The valid value is [1, length of forked sequence]. If the position equals the length of forked\n   * sequence, the new sequence will copy the entire forked sequence.\n   * \\throw Error if a sequence ID or the\n   * forked position is not valid.\n   */\n  void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) {\n    TVM_FFI_ICHECK(seq2page.find(seq_id) == seq2page.end());\n    TVM_FFI_ICHECK(seq2page.find(parent_seq_id) != seq2page.end());\n    TVM_FFI_ICHECK_GT(forked_offset, 0);\n    size_t length = GetSequenceLength(parent_seq_id);\n    TVM_FFI_ICHECK_LE(forked_offset, length);\n    for (RadixPage* page = seq2page[parent_seq_id]; page; page = page->parent) {\n      if (forked_offset > length - page->length) {\n        if (forked_offset < length) {\n          // Split radix page if forked position is within page\n          page = SplitPage(page, forked_offset + page->length - length);\n        }\n        page->AddSequence(seq_id_node_pool, seq_id);\n        seq2page[seq_id] = page;\n        return;\n      }\n      length -= page->length;\n    }\n  }\n\n  /*!\n   * \\brief Add an empty sequence at root.\n   * \\param seq_id The new sequence ID.\n   * \\throw Error if sequence ID is not valid.\n   */\n  void AddSequence(int64_t 
seq_id) {\n    
TVM_FFI_ICHECK(seq2page.find(seq_id) == seq2page.end())\n        << \"Sequence ID = \" << seq_id << \" has been added.\";\n    root->AddSequence(seq_id_node_pool, seq_id);\n    seq2page[seq_id] = root;\n  }\n\n  /*!\n   * \\brief Extend a sequence with given tokens.\n   * \\param seq_id The sequence ID for index.\n   * \\param tokens The given tokens to extend.\n   * \\throw Error if sequence ID is not valid.\n   */\n  void ExtendSequence(int64_t seq_id, const std::vector<int32_t>& tokens) {\n    TVM_FFI_ICHECK(seq2page.find(seq_id) != seq2page.end());\n    const int32_t* suffix = tokens.data();\n    size_t length = tokens.size();\n    RadixPage* original_page = seq2page[seq_id];\n    original_page->PopSequence(seq_id_node_pool, seq_id);\n    auto [page, offset, in_page_offset] = MatchSequence(original_page, suffix, length);\n    if (in_page_offset < page->length) {\n      // Split page if extended sequence mismatches within page\n      page = SplitPage(page, in_page_offset);\n    }\n    if (offset < length && !page->seq_ids && !page->first_child && page->capacity > page->length) {\n      // Extend in the existing leaf page first if possible.\n      size_t suffix_length = std::min(page->capacity - page->length, length - offset);\n      page->Extend(suffix + offset, suffix_length);\n      offset += suffix_length;\n    }\n    while (offset < length) {\n      // Allocate new radix page and extend tokens\n      RadixPage* new_page = radix_page_pool->Allocate();\n      page->InsertChild(new_page);\n      page = new_page;\n      size_t suffix_length = std::min(page->capacity - page->length, length - offset);\n      page->Extend(suffix + offset, suffix_length);\n      offset += suffix_length;\n    }\n    page->AddSequence(seq_id_node_pool, seq_id);\n    seq2page[seq_id] = page;\n    if (original_page->Mergeable()) {\n      // The original page may be mergeable, as the sequence ID changes\n      MergePage(original_page);\n    }\n  }\n\n  /*!\n   * \\brief Roll back a 
sequence by the given number of tokens.\n   * \\param seq_id The sequence ID for index.\n   * \\param num_tokens The number of tokens to be rolled back.\n   * \\throw Error if sequence ID is not valid.\n   */\n  void RollBackSequence(int64_t seq_id, size_t num_tokens) {\n    size_t length = GetSequenceLength(seq_id);\n    TVM_FFI_ICHECK_GT(num_tokens, 0);\n    TVM_FFI_ICHECK_LE(num_tokens, length);\n    if (num_tokens == length) {\n      // Rolling back the whole sequence equals removing it and re-adding it as empty.\n      RemoveSequence(seq_id);\n      AddSequence(seq_id);\n      return;\n    }\n    RadixPage* page = seq2page[seq_id];\n    // Temporarily detach the sequence from its page (keeping the token data), then roll back.\n    page->PopSequence(seq_id_node_pool, seq_id);\n    seq2page.erase(seq_id);\n    while (page->length <= num_tokens) {\n      // Roll back entire page\n      num_tokens -= page->length;\n      RadixPage* parent = page->parent;\n      if (page->seq_ids == nullptr && page->first_child == nullptr) {\n        // The leaf page is removable\n        parent->RemoveChild(page);\n        radix_page_pool->Free(page);\n      }\n      page = parent;\n    }\n    if (page->seq_ids == nullptr && page->first_child == nullptr) {\n      // The page is a leaf page, so roll back by shrinking its length directly\n      page->length -= num_tokens;\n      // Update the mapping from sequence to page\n      page->AddSequence(seq_id_node_pool, seq_id);\n      seq2page[seq_id] = page;\n      return;\n    }\n    // Otherwise split the page so that the rolled-back sequence ends at a page boundary\n    if (num_tokens) {\n      page = SplitPage(page, page->length - num_tokens);\n    }\n    // Update the mapping from sequence to page\n    page->AddSequence(seq_id_node_pool, seq_id);\n    seq2page[seq_id] = page;\n  }\n\n  /*!\n   * \\brief Remove a sequence.\n   * \\param seq_id The sequence ID to remove.\n   * \\throw Error if sequence ID is not valid.\n   */\n  void RemoveSequence(int64_t seq_id) {\n    
TVM_FFI_ICHECK(seq2page.find(seq_id) != seq2page.end());\n    RadixPage* page = 
seq2page[seq_id];\n    page->PopSequence(seq_id_node_pool, seq_id);\n    seq2page.erase(seq_id);\n    while (page->parent && !page->seq_ids && !page->first_child) {\n      RadixPage* parent = page->parent;\n      parent->RemoveChild(page);\n      radix_page_pool->Free(page);\n      page = parent;\n    }\n    if (page && page->Mergeable()) {\n      // The remaining page may become mergeable once the sequence ID is removed\n      MergePage(page);\n    }\n  }\n\n  /*!\n   * \\brief Get the remaining token capacity of the paged radix tree.\n   * \\return The remaining token capacity of the paged radix tree.\n   */\n  size_t FreeCapacity() { return radix_page_pool->FreeCapacity(); }\n\n  /*! \\brief Reset the paged radix tree to initial status. */\n  void Reset() {\n    radix_page_pool->Reset();\n    seq_id_node_pool->Reset();\n    seq2page.clear();\n    root->parent = root->first_child = root->next_sibling = nullptr;\n    root->offset = root->length = root->capacity = 0;\n    root->seq_ids = nullptr;\n  }\n\n  /*! \\brief The destructor, freeing the root page and both pools. */\n  ~PagedRadixTreeImpl() {\n    delete[] reinterpret_cast<int32_t*>(root);\n    delete seq_id_node_pool;\n    delete radix_page_pool;\n  }\n\n private:\n  /*!\n   * \\brief Merge a radix tree page with its only child page, to save radix tree pages.\n   * e.g. 
MergePage([1, 2, _, _, _] -> [3, 4, 5, _, _]) = [1, 2, 3, 4, 5].\n   * The page to be merged must satisfy page->Mergeable().\n   * \\param page The parent radix tree page.\n   */\n  void MergePage(RadixPage* page) {\n    TVM_FFI_ICHECK(page->Mergeable());\n    RadixPage* child = page->first_child;\n    for (size_t i = 0; i < child->length; ++i) {\n      (*page)[i + page->length] = (*child)[i];\n    }\n    page->length += child->length;\n    page->first_child = child->first_child;\n    for (RadixPage* p = child->first_child; p; p = p->next_sibling) {\n      p->parent = page;\n    }\n    page->seq_ids = child->seq_ids;\n    std::vector<int64_t> seq_ids = page->GetLocalSequence();\n    for (int64_t id : seq_ids) seq2page[id] = page;\n    child->seq_ids = nullptr;\n    radix_page_pool->Free(child);\n  }\n\n  /*!\n   * \\brief Split a radix tree page at the given position, so that a new sequence can end there.\n   * e.g. SplitPage([1, 2, 3, 4, 5], 2) = [1, 2, _, _, _] -> [3, 4, 5, _, _].\n   * \\param page The radix tree page to split.\n   * \\param offset The position to split the radix tree page.\n   * \\return The page holding the front part of the split. 
It can differ from the input radix tree page, as\n   * an implicit radix tree page merge may happen.\n   */\n  RadixPage* SplitPage(RadixPage* page, size_t offset) {\n    TVM_FFI_ICHECK_LT(offset, page->length);\n    RadixPage* child = radix_page_pool->Allocate();\n    child->parent = page;\n    child->first_child = page->first_child;\n    for (RadixPage* p = page->first_child; p; p = p->next_sibling) {\n      p->parent = child;\n    }\n    page->first_child = child;\n    for (size_t i = offset; i < page->length; ++i) {\n      (*child)[i - offset] = (*page)[i];\n    }\n    child->length = page->length - offset;\n    page->length = offset;\n    child->seq_ids = page->seq_ids;\n    std::vector<int64_t> seq_ids = page->GetLocalSequence();\n    for (int64_t id : seq_ids) seq2page[id] = child;\n    page->seq_ids = nullptr;\n    if (child->Mergeable()) {\n      // The child page may be mergeable\n      MergePage(child);\n    }\n    if (page->parent && page->parent->Mergeable()) {\n      // The parent page may be mergeable\n      page = page->parent;\n      MergePage(page);\n    }\n    return page;\n  }\n\n  /*!\n   * \\brief Match the given tokens from a radix tree page, stopping at the first mismatch.\n   * \\param page The radix tree page to start matching.\n   * \\param tokens The given tokens to match.\n   * \\param length The length of given tokens.\n   * \\return The tuple of (the page where matching stops, the total matched length, the matched\n   * length within that page).\n   */\n  std::tuple<RadixPage*, size_t, size_t> MatchSequence(RadixPage* page, const int32_t* tokens,\n                                                       size_t length) {\n    size_t offset = 0;\n    while (offset < length) {\n      if (RadixPage* child = page->FindChild(tokens[offset])) {\n        // The child page starts with the offset-th token, so the common prefix extends into it\n        size_t matched_offset = child->MatchPrefix(tokens + offset, length - offset);\n        offset += matched_offset;\n        if (matched_offset 
< child->length) {\n          // Common prefix ends within child page\n          
return std::make_tuple(child, offset, matched_offset);\n        }\n        page = child;\n      } else {\n        // No child page starts with offset-th token, common prefix ends with current page\n        return std::make_tuple(page, offset, page->length);\n      }\n    }\n    return std::make_tuple(page, length, page->length);\n  }\n};\n\nPagedRadixTree PagedRadixTree::Create() {\n  return PagedRadixTree(tvm::ffi::make_object<PagedRadixTreeImpl>());\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.serve.PagedRadixTree\", []() { return PagedRadixTree::Create(); })\n      .def(\"mlc.serve.PagedRadixTreeMatchPrefix\",\n           [](PagedRadixTree paged_radix_tree, IntTuple tokens) {\n             std::vector<int32_t> token_ids{tokens.begin(), tokens.end()};\n             auto [offset, seq_ids] = paged_radix_tree->MatchPrefix(token_ids);\n             seq_ids.insert(seq_ids.begin(), offset);\n             return IntTuple(seq_ids);\n           })\n      .def(\"mlc.serve.PagedRadixTreeExtendSequence\",\n           [](PagedRadixTree paged_radix_tree, int64_t seq_id, IntTuple tokens) {\n             std::vector<int32_t> token_ids{tokens.begin(), tokens.end()};\n             paged_radix_tree->ExtendSequence(seq_id, std::move(token_ids));\n           })\n      .def(\"mlc.serve.PagedRadixTreeRollBackSequence\",\n           [](PagedRadixTree paged_radix_tree, int64_t seq_id, int64_t num_tokens) {\n             paged_radix_tree->RollBackSequence(seq_id, num_tokens);\n           })\n      .def(\"mlc.serve.PagedRadixTreeForkSequence\",\n           [](PagedRadixTree paged_radix_tree, int64_t seq_id, int64_t parent_seq_id,\n              uint64_t forked_offset) {\n             paged_radix_tree->ForkSequence(seq_id, parent_seq_id, forked_offset);\n           })\n      .def_method(\"mlc.serve.PagedRadixTreeHasSequence\", &PagedRadixTreeObj::HasSequence)\n      .def_method(\"mlc.serve.PagedRadixTreeAddSequence\", 
&PagedRadixTreeObj::AddSequence)\n      .def_method(\"mlc.serve.PagedRadixTreeRemoveSequence\", &PagedRadixTreeObj::RemoveSequence)\n      .def_method(\"mlc.serve.PagedRadixTreeGetSequence\", &PagedRadixTreeObj::GetSequence)\n      .def(\"mlc.serve.PagedRadixTreeGetSequenceLength\",\n           [](PagedRadixTree paged_radix_tree, int64_t seq_id) {\n             return static_cast<int64_t>(paged_radix_tree->GetSequenceLength(seq_id));\n           })\n      .def(\"mlc.serve.PagedRadixTreeFreeCapacity\", [](PagedRadixTree paged_radix_tree) {\n        return static_cast<int64_t>(paged_radix_tree->FreeCapacity());\n      });\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/radix_tree.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/radix_tree.h\n */\n#ifndef MLC_LLM_SERVE_RADIX_TREE_H_\n#define MLC_LLM_SERVE_RADIX_TREE_H_\n#include <tvm/ffi/container/shape.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/int_tuple.h>\n#include <tvm/runtime/object.h>\n\n#include <unordered_map>\n#include <unordered_set>\n#include <utility>\n#include <vector>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The paged radix tree data structure.\n */\nclass PagedRadixTreeObj : public Object {\n public:\n  /*!\n   * \\brief Check if a sequence exists.\n   * \\param seq_id The sequence ID for index.\n   * \\return Whether the sequence exists.\n   */\n  virtual bool HasSequence(int64_t seq_id) = 0;\n\n  /*!\n   * \\brief Get all tokens of a sequence.\n   * \\param seq_id The sequence ID for index.\n   * \\return The sequence tokens.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual IntTuple GetSequence(int64_t seq_id) = 0;\n\n  /*!\n   * \\brief Get all sequences sharing the longest common prefix with the given prefix tokens.\n   * \\param tokens The prefix tokens for reference.\n   * \\return The pair of the matched prefix length and the IDs of the matched sequences.\n   */\n  virtual std::pair<size_t, std::vector<int64_t>> MatchPrefix(\n      const std::vector<int32_t>& tokens) = 0;\n\n  /*!\n   * \\brief Get a sequence's length.\n   * \\param seq_id The sequence ID for index.\n   * \\return The sequence length.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual size_t GetSequenceLength(int64_t seq_id) = 0;\n\n  /*!\n   * \\brief Fork a new sequence from a parent sequence at the given position.\n   * \\param seq_id The new sequence ID.\n   * \\param parent_seq_id The parent sequence ID to fork from.\n   * \\param forked_offset The position of parent sequence to fork at.\n   * The valid value is [1, length of forked sequence]. 
If the position equals the length of forked\n   * sequence, the new sequence will copy the entire forked sequence.\n   * \\throw Error if a sequence ID or the\n   * forked position is not valid.\n   */\n  virtual void ForkSequence(int64_t seq_id, int64_t parent_seq_id, size_t forked_offset) = 0;\n\n  /*!\n   * \\brief Add an empty sequence at root.\n   * \\param seq_id The new sequence ID.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual void AddSequence(int64_t seq_id) = 0;\n\n  /*!\n   * \\brief Extend a sequence with given tokens.\n   * \\param seq_id The sequence ID for index.\n   * \\param tokens The given tokens to extend.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual void ExtendSequence(int64_t seq_id, const std::vector<int32_t>& tokens) = 0;\n\n  /*!\n   * \\brief Roll back a sequence by the given number of tokens.\n   * \\param seq_id The sequence ID for index.\n   * \\param num_tokens The number of tokens to be rolled back.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual void RollBackSequence(int64_t seq_id, size_t num_tokens) = 0;\n\n  /*!\n   * \\brief Remove a sequence.\n   * \\param seq_id The sequence ID to remove.\n   * \\throw Error if sequence ID is not valid.\n   */\n  virtual void RemoveSequence(int64_t seq_id) = 0;\n\n  /*!\n   * \\brief Get the remaining token capacity of the paged radix tree.\n   * \\return The remaining token capacity of the paged radix tree.\n   */\n  virtual size_t FreeCapacity() = 0;\n\n  /*!\n   * \\brief Reset the paged radix tree to initial status.\n   */\n  virtual void Reset() = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<PagedRadixTreeObj>();\n  }\n\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.PagedRadixTree\", PagedRadixTreeObj, Object);\n};\n\nclass PagedRadixTree : public ObjectRef {\n public:\n  /*!\n   * \\brief Construct a paged radix tree.\n   * 
\\return The constructed paged radix tree.\n   */\n  static PagedRadixTree Create();\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(PagedRadixTree, ObjectRef, PagedRadixTreeObj);\n};\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_RADIX_TREE_H_\n"
  },
  {
    "path": "cpp/serve/request.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/request.cc\n */\n\n#include \"request.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n\n#include \"data.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\n/****************** Request ******************/\n\nTVM_FFI_STATIC_INIT_BLOCK() { RequestNode::RegisterReflection(); }\n\nRequest::Request(String id, Array<Data> inputs, GenerationConfig generation_cfg) {\n  if (generation_cfg->debug_config.special_request == SpecialRequestKind::kNone) {\n    TVM_FFI_ICHECK(!inputs.empty()) << \"No input data is given.\";\n  }\n  // Compute the total input length, or fall back to \"-1\" which means\n  // unknown due to the existence of untokenized data.\n  int prompt_tokens = 0;\n  for (Data input : inputs) {\n    if (const auto* token_data = input.as<TokenDataNode>()) {\n      prompt_tokens += token_data->token_ids.size();\n    } else if (const auto* image_data = input.as<ImageDataNode>()) {\n      prompt_tokens += image_data->GetLength();\n    } else {\n      prompt_tokens = -1;\n      break;\n    }\n  }\n\n  ObjectPtr<RequestNode> n = tvm::ffi::make_object<RequestNode>();\n  n->id = std::move(id);\n  n->inputs = std::move(inputs);\n  n->prompt_tokens = prompt_tokens;\n  n->generation_cfg = std::move(generation_cfg);\n  data_ = std::move(n);\n}\n\nRequest Request::FromUntokenized(const Request& request, const Tokenizer& tokenizer) {\n  bool has_untokenized_input = false;\n  Array<Data> inputs;\n  inputs.reserve(request->inputs.size());\n  // Tokenize all text inputs.\n  for (Data input : request->inputs) {\n    if (const auto* text_data = input.as<TextDataNode>()) {\n      has_untokenized_input = true;\n      std::vector<int> token_ids = tokenizer->Encode(text_data->text);\n      inputs.push_back(TokenData(token_ids));\n    } else {\n      inputs.push_back(input);\n    }\n  }\n\n  // If there is no untokenized input, we don't need to create a new 
request.\n  if (!has_untokenized_input) {\n    TVM_FFI_ICHECK_NE(request->prompt_tokens, -1);\n    return request;\n  } else {\n    return Request(request->id, std::move(inputs), request->generation_cfg);\n  }\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.serve.RequestGetInputs\", [](Request request) { return request->inputs; })\n      .def(\"mlc.serve.RequestGetGenerationConfigJSON\", [](Request request) {\n        return tvm::ffi::json::Stringify(request->generation_cfg->AsJSON());\n      });\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/request.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/request.h\n * \\brief Definition of the user-submitted serving request.\n */\n#ifndef MLC_LLM_SERVE_REQUEST_H_\n#define MLC_LLM_SERVE_REQUEST_H_\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/object.h>\n\n#include \"../tokenizers/tokenizers.h\"\n#include \"config.h\"\n#include \"data.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/****************** Request ******************/\n\n/*!\n * \\brief The user-submitted text-generation request, which contains\n * a unique request id, a list of multi-modal inputs, and a set of\n * generation configuration parameters.\n * \\note A request is immutable; it can be re-dispatched to another\n * node, which then restarts the request handling.\n */\nclass RequestNode : public Object {\n public:\n  /*!\n   * \\brief The unique identifier of the request.\n   * Different requests should have different ids.\n   */\n  String id;\n  /*!\n   * \\brief The user inputs of a request. Input may have multi-modality.\n   * \\sa data.h\n   */\n  Array<Data> inputs;\n  /*!\n   * \\brief The equivalent input sequence length of the request.\n   * \"-1\" means the input length is unknown due to the existence\n   * of untokenized text data.\n   */\n  int prompt_tokens = -1;\n  /*!\n   * \\brief The sampling configuration which may contain temperature,\n   * top_p, repetition_penalty, max_gen_len, etc.\n   */\n  GenerationConfig generation_cfg;\n  /*! \\brief Backward reference to the request state. 
*/\n  Object* rstate = nullptr;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<RequestNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.Request\", RequestNode, Object);\n};\n\nclass Request : public ObjectRef {\n public:\n  explicit Request(String id, Array<Data> inputs, GenerationConfig generation_cfg);\n\n  /*!\n   * \\brief Return a request object with all text data tokenized,\n   * and the request ID kept the same as the input one.\n   * \\param request The request to be tokenized.\n   * \\param tokenizer The tokenizer to tokenize the input data of the given request.\n   * \\return The request object whose data are tokenized.\n   */\n  static Request FromUntokenized(const Request& request, const Tokenizer& tokenizer);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Request, ObjectRef, RequestNode);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_REQUEST_H_\n"
  },
  {
    "path": "cpp/serve/request_state.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/request_state.cc\n */\n\n#include \"request_state.h\"\n\n#include <unordered_set>\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  RequestModelStateNode::RegisterReflection();\n  RequestStateEntryNode::RegisterReflection();\n  RequestStateNode::RegisterReflection();\n}\n\n/****************** RequestModelState ******************/\n\nRequestModelState::RequestModelState(\n    Request request, int model_id, int64_t internal_id, Array<Data> inputs,\n    const std::optional<xgrammar::CompiledGrammar>& compiled_grammar) {\n  ObjectPtr<RequestModelStateNode> n = tvm::ffi::make_object<RequestModelStateNode>();\n  n->model_id = model_id;\n  n->internal_id = internal_id;\n  n->inputs = std::move(inputs);\n\n  if (compiled_grammar.has_value()) {\n    // TODO(yixin): set rollback limit to a configurable value.\n    n->grammar_matcher =\n        xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10);\n  }\n\n  n->request = std::move(request);\n  data_ = std::move(n);\n}\n\nint RequestModelStateNode::GetInputLength() const {\n  int total_length = 0;\n  for (Data input : inputs) {\n    total_length += input->GetLength();\n  }\n  return total_length;\n}\n\nbool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.has_value(); }\n\nvoid RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) {\n  TVM_FFI_ICHECK(grammar_matcher.has_value());\n\n  grammar_matcher->GetNextTokenBitmask(bitmask);\n}\n\nvoid RequestModelStateNode::CommitToken(SampleResult sampled_token) {\n  committed_tokens.push_back(std::move(sampled_token));\n  appeared_token_ids[sampled_token.GetTokenId()] += 1;\n  // There will be one more token that will be processed in the next decoding.\n  ++num_tokens_for_next_decode;\n\n  // Update the grammar matcher state if it exists.\n  if (grammar_matcher) {\n    bool accepted = 
grammar_matcher->AcceptToken(sampled_token.GetTokenId());\n    TVM_FFI_ICHECK(accepted) << \"Token id \" << sampled_token.GetTokenId()\n                             << \" is not accepted by the grammar state matcher.\";\n  }\n}\n\nvoid RequestModelStateNode::RollbackTokens(int count) {\n  TVM_FFI_ICHECK(count <= static_cast<int>(committed_tokens.size()));\n  for (int i = 0; i < count; ++i) {\n    auto it = appeared_token_ids.find(committed_tokens.back().GetTokenId());\n    TVM_FFI_ICHECK(it != appeared_token_ids.end());\n    if (--it->second == 0) {\n      appeared_token_ids.erase(it);\n    }\n    committed_tokens.pop_back();\n    if (grammar_matcher) {\n      grammar_matcher->Rollback(1);\n    }\n  }\n}\n\nvoid RequestModelStateNode::AddDraftToken(SampleResult sampled_token, int draft_token_slot,\n                                          int64_t parent_idx) {\n  draft_output_tokens.push_back(std::move(sampled_token));\n  draft_token_slots.push_back(draft_token_slot);\n  draft_token_parent_idx.push_back(parent_idx);\n  draft_token_first_child_idx.push_back(-1);\n  if (parent_idx != -1) {\n    if (draft_token_first_child_idx[parent_idx] == -1) {\n      draft_token_first_child_idx[parent_idx] = static_cast<int>(draft_output_tokens.size()) - 1;\n    }\n  }\n}\n\nvoid RequestModelStateNode::RemoveAllDraftTokens(std::vector<int>* removed_draft_token_slots) {\n  if (removed_draft_token_slots != nullptr) {\n    std::unordered_set<int> dedup;\n    removed_draft_token_slots->clear();\n    for (auto slot : draft_token_slots) {\n      bool inserted = dedup.insert(slot).second;\n      if (inserted) {\n        removed_draft_token_slots->push_back(slot);\n      }\n    }\n  }\n  draft_token_slots.clear();\n  draft_token_parent_idx.clear();\n  draft_token_first_child_idx.clear();\n  draft_output_tokens.clear();\n}\n\n/****************** RequestActionPostProcWorkspace ******************/\n\nRequestStreamOutput RequestActionPostProcWorkspace::GetStreamOutput() {\n  for (const 
RequestStreamOutput& stream_output : stream_outputs) {\n    if (stream_output->unpacked) {\n      return stream_output;\n    }\n  }\n\n  TVM_FFI_ICHECK(!stream_outputs.empty());\n  int num_response = stream_outputs[0]->group_delta_token_ids.size();\n  std::vector<std::vector<int64_t>> group_delta_token_ids;\n  std::vector<std::vector<String>> group_delta_logprob_json_strs;\n  std::vector<Optional<String>> group_finish_reason;\n  std::vector<String> group_extra_prefix_string;\n  group_delta_token_ids.resize(num_response);\n  group_finish_reason.resize(num_response);\n  group_extra_prefix_string.resize(num_response);\n  if (stream_outputs[0]->group_delta_logprob_json_strs.has_value()) {\n    group_delta_logprob_json_strs.resize(num_response);\n  }\n  RequestStreamOutput stream_output(stream_outputs[0]->request_id, std::move(group_delta_token_ids),\n                                    stream_outputs[0]->group_delta_logprob_json_strs.has_value()\n                                        ? std::make_optional(group_delta_logprob_json_strs)\n                                        : std::nullopt,\n                                    std::move(group_finish_reason),\n                                    std::move(group_extra_prefix_string));\n  stream_outputs.push_back(stream_output);\n  return stream_output;\n}\n\n/****************** RequestStateEntry ******************/\n\nRequestStateEntry::RequestStateEntry(\n    Request request, int num_models, int64_t internal_id, int rng_seed,\n    const std::vector<std::string>& token_table,\n    const std::optional<xgrammar::CompiledGrammar>& compiled_grammar, int parent_idx) {\n  ObjectPtr<RequestStateEntryNode> n = tvm::ffi::make_object<RequestStateEntryNode>();\n  Array<RequestModelState> mstates;\n  Array<Data> inputs;\n  if (parent_idx == -1) {\n    inputs = request->inputs;\n  }\n  mstates.reserve(num_models);\n  for (int i = 0; i < num_models; ++i) {\n    mstates.push_back(RequestModelState(request, i, internal_id, inputs, 
compiled_grammar));\n  }\n  n->status = RequestStateStatus::kPending;\n  n->rng = RandomGenerator(rng_seed);\n  n->stop_str_handler = StopStrHandler(!request->generation_cfg->debug_config.ignore_eos\n                                           ? request->generation_cfg->stop_strs\n                                           : Array<String>(),\n                                       token_table);\n  n->request = std::move(request);\n  n->parent_idx = parent_idx;\n  n->mstates = std::move(mstates);\n  n->next_callback_token_pos = 0;\n  data_ = std::move(n);\n}\n\nvoid RequestStateEntryNode::GetDeltaRequestReturn(const Tokenizer& tokenizer,\n                                                  int64_t max_single_sequence_length,\n                                                  RequestStreamOutput* delta_stream_output,\n                                                  int idx) {\n  TVM_FFI_ICHECK_NOTNULL(delta_stream_output);\n  bool needs_logprobs = (*delta_stream_output)->group_delta_logprob_json_strs.has_value();\n  (*delta_stream_output)->group_delta_token_ids[idx].clear();\n  if (needs_logprobs) {\n    (*delta_stream_output)->group_delta_logprob_json_strs.value()[idx].clear();\n  }\n  (*delta_stream_output)->group_finish_reason[idx] = std::nullopt;\n  (*delta_stream_output)->group_extra_prefix_string[idx] = this->extra_prefix_string;\n  this->extra_prefix_string.clear();\n\n  const std::vector<SampleResult>& committed_tokens = this->mstates[0]->committed_tokens;\n  int num_committed_tokens = committed_tokens.size();\n  TVM_FFI_ICHECK_LE(this->next_callback_token_pos, num_committed_tokens);\n\n  // Case 1. There are no new token ids.\n  if (this->next_callback_token_pos == num_committed_tokens && extra_prefix_string.empty()) {\n    return;\n  }\n\n  // Case 2. 
Any of the stop strings is matched.\n  TVM_FFI_ICHECK(!stop_str_handler->StopTriggered());\n  while (next_callback_token_pos < num_committed_tokens) {\n    stop_str_handler->Put(committed_tokens[next_callback_token_pos].GetTokenId(),\n                          &(*delta_stream_output)->group_delta_token_ids[idx]);\n    if (needs_logprobs) {\n      (*delta_stream_output)\n          ->group_delta_logprob_json_strs.value()[idx]\n          .push_back(committed_tokens[next_callback_token_pos].GetLogProbJSON(\n              tokenizer, request->generation_cfg->logprobs));\n    }\n    ++next_callback_token_pos;\n    if (stop_str_handler->StopTriggered()) {\n      (*delta_stream_output)->group_finish_reason[idx] = \"stop\";\n      break;\n    }\n  }\n\n  // Case 3. Any of the stop tokens appears in the committed tokens ===> Finished\n  // `stop_token_ids` includes the stop tokens from conversation template and user-provided tokens.\n  // This check will be ignored when `ignore_eos` is set for the benchmarking purpose.\n  if (!request->generation_cfg->debug_config.ignore_eos) {\n    for (int i = 0; i < static_cast<int>((*delta_stream_output)->group_delta_token_ids[idx].size());\n         ++i) {\n      if (std::any_of(request->generation_cfg->stop_token_ids.begin(),\n                      request->generation_cfg->stop_token_ids.end(),\n                      [delta_stream_output, idx, i](int32_t token) {\n                        return token == (*delta_stream_output)->group_delta_token_ids[idx][i];\n                      })) {\n        // Stop token matched. Erase the stop token and all tokens after it.\n        (*delta_stream_output)->group_finish_reason[idx] = \"stop\";\n        while (static_cast<int>((*delta_stream_output)->group_delta_token_ids[idx].size()) > i) {\n          (*delta_stream_output)->group_delta_token_ids[idx].pop_back();\n        }\n        break;\n      }\n    }\n  }\n\n  // Case 4. When stop token is not detected (e.g. 
ignore_eos is set), but the grammar state is\n  // terminated, stop the generation and pop the last token (used to trigger the termination).\n  if ((*delta_stream_output)->group_finish_reason[idx] != \"stop\" &&\n      this->mstates[0]->grammar_matcher.has_value() &&\n      this->mstates[0]->grammar_matcher->IsTerminated()) {\n    (*delta_stream_output)->group_delta_token_ids[idx].pop_back();\n    (*delta_stream_output)->group_finish_reason[idx] = \"stop\";\n  }\n\n  if ((*delta_stream_output)->group_finish_reason[idx].has_value()) {\n    return;\n  }\n\n  // Case 5. Generation reaches the specified max generation length ==> Finished\n  // `max_tokens` means the generation length is limited by model capacity.\n  if (request->generation_cfg->max_tokens >= 0 &&\n      num_committed_tokens >= request->generation_cfg->max_tokens) {\n    stop_str_handler->Finish(&(*delta_stream_output)->group_delta_token_ids[idx]);\n    (*delta_stream_output)->group_finish_reason[idx] = \"length\";\n    return;\n  }\n  // Case 6. 
Total length of the request reaches the maximum single sequence length ==> Finished\n  if (request->prompt_tokens + num_committed_tokens >= max_single_sequence_length) {\n    stop_str_handler->Finish(&(*delta_stream_output)->group_delta_token_ids[idx]);\n    (*delta_stream_output)->group_finish_reason[idx] = \"length\";\n  }\n}\n\n/****************** RequestState ******************/\n\nRequestState::RequestState(std::vector<RequestStateEntry> entries, int num_response,\n                           std::chrono::high_resolution_clock::time_point add_time_point) {\n  TVM_FFI_ICHECK(!entries.empty());\n  ObjectPtr<RequestStateNode> n = tvm::ffi::make_object<RequestStateNode>();\n  n->entries = std::move(entries);\n  n->metrics.prompt_tokens = n->entries[0]->request->prompt_tokens;\n  n->metrics.add_time_point = add_time_point;\n\n  std::vector<std::vector<int64_t>> group_delta_token_ids;\n  std::vector<std::vector<String>> group_delta_logprob_json_strs;\n  std::vector<Optional<String>> group_finish_reason;\n  std::vector<String> group_extra_prefix_string;\n  group_delta_token_ids.resize(num_response);\n  group_finish_reason.resize(num_response);\n  group_extra_prefix_string.resize(num_response);\n  if (n->entries[0]->request->generation_cfg->logprobs) {\n    group_delta_logprob_json_strs.resize(num_response);\n  }\n  RequestStreamOutput stream_output(n->entries[0]->request->id, std::move(group_delta_token_ids),\n                                    n->entries[0]->request->generation_cfg->logprobs\n                                        ? 
std::make_optional(group_delta_logprob_json_strs)\n                                        : std::nullopt,\n                                    std::move(group_finish_reason),\n                                    std::move(group_extra_prefix_string));\n  stream_output->unpacked = true;\n  n->postproc_states.stream_outputs = {std::move(stream_output)};\n  data_ = std::move(n);\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/request_state.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/request_state.h\n * \\brief The data structure maintaining the generation states of user requests.\n */\n#ifndef MLC_LLM_SERVE_REQUEST_STATE_H_\n#define MLC_LLM_SERVE_REQUEST_STATE_H_\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/runtime/object.h>\n#include <tvm/runtime/tensor.h>\n#include <xgrammar/xgrammar.h>\n\n#include <optional>\n\n#include \"../support/random.h\"\n#include \"../tokenizers/streamer.h\"\n#include \"config.h\"\n#include \"metrics.h\"\n#include \"request.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The state of a request with regard to some single model.\n * \\details In MLC LLM, the serving engine may leverage multiple models\n * to fulfill a user generation request (e.g., speculative decoding).\n * For each request, we isolate its states (e.g. the generated tokens)\n * on each model. That is, we use RequestModelState to store\n * the state of a user request on a single model (rather than all models).\n */\nclass RequestModelStateNode : public Object {\n public:\n  /*! \\brief The request that this state corresponds to. */\n  Request request;\n  /*!\n   * \\brief The internal request id of this state.\n   * It is the **physical index** of the request in the running request queue.\n   * If the request is on hold (not in the running queue), the request id\n   * should be -1.\n   */\n  int64_t internal_id = -1;\n  /*! \\brief The corresponding model id of this state. */\n  int model_id = -1;\n  /*!\n   * \\brief The committed generated token ids and related probability info.\n   * A token being \"committed\" means it will no longer be updated (or changed).\n   */\n  std::vector<SampleResult> committed_tokens;\n  /*! \\brief The list of input data that the model has yet to prefill. */\n  Array<Data> inputs;\n  /*! \\brief The list of prefilled input data, used to notify the prefix cache.
 */\n  std::vector<Data> prefilled_inputs;\n  /*! \\brief The number of tokens already cached in the prefix cache. */\n  int64_t cached_committed_tokens = 0;\n  /*! \\brief The number of tokens from the inputs that have already been prefilled. */\n  int64_t num_prefilled_tokens = 0;\n  /*! \\brief The number of tokens that need to be processed in the next decoding. */\n  int num_tokens_for_next_decode = 0;\n  /*! \\brief Whether retokenization is needed in the next decoding. When jump-forward decoding\n   * is enabled, retokenization is needed after every jump-forward and decoding action. */\n  bool require_retokenization_in_next_decode = false;\n\n  // NOTE: The following fields are reserved for future speculative inference\n  // settings, and are produced by the speculative small models.\n  /*!\n   * \\brief The draft generated token ids and related probability info,\n   * which are usually generated by \"small\" speculative models.\n   * These tokens will be fed to a \"large\" model to determine the final\n   * result of speculation.\n   */\n  std::vector<SampleResult> draft_output_tokens;\n  /*! \\brief The storage slots for the associated states of draft tokens. */\n  std::vector<int> draft_token_slots;\n  /*! \\brief The parent indices of the draft tokens. */\n  std::vector<int64_t> draft_token_parent_idx;\n  /*! \\brief The first child indices of the draft tokens. */\n  std::vector<int64_t> draft_token_first_child_idx;\n\n  /*! \\brief The committed and draft token ids that have appeared, with their occurrence counts. */\n  std::unordered_map<int32_t, int32_t> appeared_token_ids;\n\n  /*!\n   * \\brief The current state of the generated token matching the grammar. Used in grammar-guided\n   * generation, otherwise it's std::nullopt.\n   */\n  std::optional<xgrammar::GrammarMatcher> grammar_matcher;\n\n  /*! \\brief Return the total length of the input data. */\n  int GetInputLength() const;\n  /*!\n   * \\brief Return whether the next token bitmask is required, i.e.
 the grammar-guided generation is\n   * enabled.\n   */\n  bool RequireNextTokenBitmask();\n  /*!\n   * \\brief Find the next token bitmask and store it in the given DLTensor.\n   * \\param bitmask The DLTensor to store the next token bitmask. The bitmask should be a tensor\n   * with dtype uint32_t and shape (ceildiv(vocab_size, 32),).\n   */\n  void GetNextTokenBitmask(DLTensor* bitmask);\n  /*! \\brief Commit a new token into committed_tokens. Does not affect the kv cache. Update\n   * appeared_token_ids and the grammar state. */\n  void CommitToken(SampleResult sampled_token);\n  /*! \\brief Roll back the last `count` tokens from committed_tokens. Does not affect the kv cache.\n   * Also roll back appeared_token_ids and the grammar state. */\n  void RollbackTokens(int count);\n\n  /*! \\brief Add a draft token into draft_output_tokens. Update appeared_token_ids. */\n  void AddDraftToken(SampleResult sampled_token, int draft_token_slot, int64_t parent_idx);\n  /*! \\brief Remove all draft tokens from draft_output_tokens. Update appeared_token_ids. 
*/\n  void RemoveAllDraftTokens(std::vector<int>* removed_draft_token_slots = nullptr);\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<RequestModelStateNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.RequestModelState\", RequestModelStateNode, Object);\n};\n\nclass RequestModelState : public ObjectRef {\n public:\n  explicit RequestModelState(Request request, int model_id, int64_t internal_id, Array<Data> inputs,\n                             const std::optional<xgrammar::CompiledGrammar>& compiled_grammar);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RequestModelState, ObjectRef, RequestModelStateNode);\n};\n\nstruct DeltaRequestReturn {\n  std::vector<int64_t> delta_token_ids;\n  std::vector<String> delta_logprob_json_strs;\n  Optional<String> finish_reason;\n  /*! \\brief The extra string to prepend to the delta output. The delta output should be\n   * extra_prefix_string + detokenize(delta_token_ids). */\n  String extra_prefix_string = \"\";\n};\n\n/****************** Request States ******************/\n\n/*!\n * \\brief For each request, we maintain its \"request state\" in the\n * engine.
 Generally, the state of a request contains the information\n * of the request's generation at the current moment, including\n * the generated token ids, the grammar handler, etc.\n *\n * When a request has multiple parallel generations (e.g., the field\n * `n` of its generation config is more than 1), each generation\n * maintains its own, independent state.\n *\n * Therefore, to better support parallel generations, we denote the\n * state of a single generation as a \"RequestStateEntry\" instance,\n * and denote the states of all generations of a request, stored in\n * a vector, as a \"RequestState\" instance.\n *\n * All state entries of a request are organized as a tree structure\n * when there are parallel generations.\n * - the request input has the root state entry,\n * - each parallel generation is a child of the root.\n * This tree structure may be further extended to more complicated\n * cases in the future. As of now, for the case of `n > 1`, there\n * will be (n + 1) entries in total. In a \"RequestState\", the root\n * entry always has index 0, and we guarantee that the entry order\n * from the beginning of the vector to the end is always a\n * topological order of the tree.\n */\n\n/*! \\brief Request state status. */\nenum class RequestStateStatus : int {\n  kPending = 0,\n  kAlive = 1,\n  kFinished = 2,\n};\n\n/*! \\brief The data structures for each request used in the action post-process. */\nstruct RequestActionPostProcWorkspace {\n  std::vector<RequestStreamOutput> stream_outputs;\n\n  RequestStreamOutput GetStreamOutput();\n};\n\n// Forward declaration of the request state node.\nclass RequestStateNode;\n\n/*!\n * \\brief A request's state entry. It contains the state of a single\n * generation of a request, or the state of a prompt prefix of a request.\n */\nclass RequestStateEntryNode : public Object {\n public:\n  /*! \\brief The status of the request state entry. */\n  RequestStateStatus status;\n  /*! \\brief The request that this state corresponds to. 
*/\n  Request request;\n  /*!\n   * \\brief The index of the parent request state entry of this state.\n   * Being -1 means the state has no parent and is the foremost\n   * \"prefix\" entry or the only entry.\n   */\n  int parent_idx = -1;\n  /*! \\brief The children indices of the request state entry. */\n  std::vector<int> child_indices;\n\n  /*!\n   * \\brief The state with regard to each model.\n   * \\sa RequestModelState\n   */\n  Array<RequestModelState> mstates;\n  /*! \\brief The random number generator of this request state entry. */\n  RandomGenerator rng;\n  /*! \\brief The stop string handler of this request state entry. */\n  StopStrHandler stop_str_handler;\n  /*!\n   * \\brief The start position of the committed tokens in the\n   * next request stream callback invocation.\n   */\n  int next_callback_token_pos;\n\n  /*! \\brief The extra string to prepend to the output. */\n  std::string extra_prefix_string;\n\n  std::vector<int32_t> token_ids_for_prefix_cache_update;\n\n  /*!\n   * \\brief Back reference to the request state.\n   * A raw pointer (rather than an ObjectRef) is used to avoid a circular reference.\n   */\n  RequestStateNode* rstate = nullptr;\n\n  /*!\n   * \\brief Get the delta token ids and the logprob JSON strings for this request to return since\n   * the last call to this function, and return the finish reason if the request\n   * generation has finished.\n   * \\note This function follows the destination passing style, which means it writes the\n   * output into the \"idx\"-th slot in \"delta_stream_output\".\n   * We adopt the destination passing style to reduce the CPU data structure allocation and\n   * construction overhead.\n   * \\param tokenizer The tokenizer used for logprob processing.\n   * \\param max_single_sequence_length The maximum allowed single sequence length.\n   * \\param delta_stream_output The delta token ids to return, the logprob JSON strings\n   * of each delta token id, and the optional finish reason.\n   * \\param idx The index denoting
 which slot in \"delta_stream_output\" to write the results to.\n   */\n  void GetDeltaRequestReturn(const Tokenizer& tokenizer, int64_t max_single_sequence_length,\n                             RequestStreamOutput* delta_stream_output, int idx);\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<RequestStateEntryNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.RequestStateEntry\", RequestStateEntryNode, Object);\n};\n\nclass RequestStateEntry : public ObjectRef {\n public:\n  explicit RequestStateEntry(Request request, int num_models, int64_t internal_id, int rng_seed,\n                             const std::vector<std::string>& token_table,\n                             const std::optional<xgrammar::CompiledGrammar>& compiled_grammar,\n                             int parent_idx = -1);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RequestStateEntry, ObjectRef, RequestStateEntryNode);\n};\n\n/*! \\brief A request's state, which groups all the request state entries. */\nclass RequestStateNode : public Object {\n public:\n  /*! \\brief The request state entries. */\n  std::vector<RequestStateEntry> entries;\n  /*! \\brief Tracks the request metrics. 
*/\n  RequestMetrics metrics;\n  /*!\n   * \\brief The post-process data structures.\n   * We keep them in the state to avoid repeated memory allocation/free in the action post-process.\n   */\n  RequestActionPostProcWorkspace postproc_states;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<RequestStateNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.serve.RequestState\", RequestStateNode, Object);\n};\n\nclass RequestState : public ObjectRef {\n public:\n  /*!\n   * \\brief Request state constructor. We take the number of responses (namely \"n\" in the OpenAI\n   * API) to pre-allocate all the data structures, in order to reduce the CPU data structure\n   * allocation overhead when updating the request state.\n   */\n  explicit RequestState(std::vector<RequestStateEntry> entries, int num_response,\n                        std::chrono::high_resolution_clock::time_point add_time_point);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(RequestState, ObjectRef, RequestStateNode);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_REQUEST_STATE_H_\n"
  },
  {
    "path": "cpp/serve/sampler/cpu_sampler.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/sampler/cpu_sampler.cc\n * \\brief The implementation for CPU sampler functions.\n */\n#include <tvm/ffi/function.h>\n#include <tvm/runtime/tensor.h>\n#include <tvm/runtime/threading_backend.h>\n\n#include <algorithm>\n#include <cmath>\n\n#include \"../../support/random.h\"\n#include \"sampler.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nTVM_FFI_STATIC_INIT_BLOCK() { SamplerObj::RegisterReflection(); }\n\n/*!\n * \\brief Sample a value from the input probability distribution with top-p.\n * The input is a batch of distributions, and we use `unit_offset` to specify\n * which distribution to sample from.\n * \\param prob The input batch of probability distributions.\n * \\param unit_offset The offset specifying which distribution to output\n * \\param input_prob_offset The offset specifying which distribution to sample from.\n * \\param top_p The top-p value of sampling.\n * \\param uniform_sample The random number in [0, 1] for sampling.\n * \\return The sampled value and probability.\n * \\note This function is an enhancement of SampleTopPFromProb in TVM Unity.\n * We will upstream the enhancement after it gets stable.\n */\nTokenProbPair SampleTopPFromProb(Tensor prob, int unit_offset, int input_prob_offset, double top_p,\n                                 double uniform_sample) {\n  // prob: (*, v)\n  // The prob array may have arbitrary ndim and shape.\n  // The last dimension corresponds to the prob distribution size.\n  // We use the `unit_offset` parameter to determine which slice\n  // of the prob array we sample from.\n\n  TVM_FFI_ICHECK(prob.IsContiguous());\n  TVM_FFI_ICHECK(prob.DataType() == DataType::Float(32));\n  TVM_FFI_ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU);\n\n  int64_t ndata = prob->shape[prob->ndim - 1];\n  const float* __restrict p_prob =\n      static_cast<float*>(__builtin_assume_aligned(prob->data, 4)) + (input_prob_offset * 
ndata);\n  constexpr double one = 1.0f - 1e-5f;\n\n  if (top_p == 0) {\n    // Specially handle case where top_p == 0.\n    // This case is equivalent to doing argmax.\n    int argmax_pos = -1;\n    float max_prob = 0.0;\n    float sum_prob = 0.0;\n    for (int i = 0; i < ndata; ++i) {\n      if (p_prob[i] > max_prob) {\n        max_prob = p_prob[i];\n        argmax_pos = i;\n      }\n      // Early exit.\n      sum_prob += p_prob[i];\n      if (1 - sum_prob <= max_prob) {\n        break;\n      }\n    }\n    return {argmax_pos, 1.0};\n  }\n\n  if (top_p >= one) {\n    // Specially handle case where top_p == 1.\n    double prob_sum = 0.0f;\n    for (int64_t i = 0; i < ndata; ++i) {\n      prob_sum += p_prob[i];\n      if (prob_sum >= uniform_sample) {\n        return {i, p_prob[i]};\n      }\n    }\n    TVM_FFI_ICHECK(false) << \"Possibly prob distribution contains NAN.\";\n  }\n\n  // Key observation: when we are doing top_p sampling\n  // usually we only need to preserve some of the elements with\n  // high probabilities before we do sort\n  thread_local std::vector<std::pair<float, int>> data;\n\n  auto sample_top_p_with_filter = [&](float cuttoff) -> std::pair<float, int64_t> {\n    data.clear();\n    // filter the data with cuttoff\n    float cutoff_sum = 0.0f;\n    for (int64_t i = 0; i < ndata; ++i) {\n      if (p_prob[i] >= cuttoff) {\n        cutoff_sum += p_prob[i];\n        data.emplace_back(std::make_pair(p_prob[i], static_cast<int>(i)));\n        if (cutoff_sum > 1 - cuttoff) {\n          // Short cut. 
When the remaining parts cannot have total\n          // probability larger than cutoff, we can quit.\n          break;\n        }\n      }\n    }\n    if (data.size() == 0) return std::make_pair(-1, -1);\n    auto fcmp = [](const std::pair<float, int>& lhs, const std::pair<float, int>& rhs) {\n      return lhs.first > rhs.first;\n    };\n    std::sort(data.begin(), data.end(), fcmp);\n\n    // short cut: if we know that\n    // uniform_sample < p[0] / top_p\n    // we know that uniform_sample < p[0] / top_p_sum\n    // because top_p_sum is guaranteed to be smaller than top_p,\n    // so we can simply return the argmax sample\n    // without computing anything\n    if (uniform_sample < data[0].first / top_p) {\n      return std::make_pair(data[0].first, data[0].second);\n    }\n\n    // compute top_p_sum\n    float cum_sum_prob = 0.0f;\n    float top_p_sum = 0.0f;\n    for (auto it = data.begin(); it != data.end(); ++it) {\n      float prob = it->first;\n      if (cum_sum_prob < top_p) {\n        top_p_sum += prob;\n      } else {\n        // we get to the right cutoff pt\n        break;\n      }\n      cum_sum_prob += prob;\n      it->first = cum_sum_prob;\n    }\n    // we find that the current total sum by the given cutoff\n    // is not sufficient to cover everything\n    // this means we might need to retry a smaller cutoff pt.\n    if (cum_sum_prob < top_p && cuttoff != 0.0f) return std::make_pair(-1, -1);\n\n    float last_cum_sum_prob = 0.0;\n    for (auto it = data.begin(); it != data.end(); ++it) {\n      if (uniform_sample < it->first / top_p_sum) {\n        return std::make_pair(it->first - last_cum_sum_prob, it->second);\n      }\n      last_cum_sum_prob = it->first;\n    }\n    return std::make_pair(data[static_cast<int64_t>(data.size()) - 1].first - last_cum_sum_prob,\n                          data[static_cast<int64_t>(data.size()) - 1].second);\n  };\n\n  if (top_p < 1) {\n    // sample through cutoff by a number\n    // by pigeonhole principle we will
 get at most 1024 elements\n    // usually it is much less after applying this filtering (on the order of 10 - 20)\n    data.reserve(256);\n    std::pair<float, int64_t> sampled_index = sample_top_p_with_filter(top_p / 1024);\n    if (sampled_index.second >= 0) return {sampled_index.second, sampled_index.first};\n  }\n  // fallback via full prob, rare case\n  data.reserve(ndata);\n  std::pair<float, int64_t> sampled_index = sample_top_p_with_filter(0.0f);\n  TVM_FFI_ICHECK_GE(sampled_index.second, 0);\n  return {sampled_index.second, sampled_index.first};\n}\n\n/*!\n * \\brief Renormalize the probability distribution by the top p value.\n * \\param prob The input batch of probability distributions.\n * \\param unit_offset The offset specifying which distribution to renormalize.\n * \\param top_p The top p value for renormalization.\n * \\param eps A small epsilon value for comparison stability.\n */\nvoid RenormalizeProbByTopP(Tensor prob, int unit_offset, double top_p, double eps) {\n  // prob: (*, v)\n  // The prob array may have arbitrary ndim and shape.\n  // The last dimension corresponds to the prob distribution size.\n  // We use the `unit_offset` parameter to determine which slice\n  // of the prob array we will renormalize.\n  TVM_FFI_ICHECK(prob.IsContiguous());\n  TVM_FFI_ICHECK(prob.DataType() == DataType::Float(32));\n  TVM_FFI_ICHECK_EQ(prob->device.device_type, DLDeviceType::kDLCPU);\n\n  if (top_p == 1.0) {\n    // No renormalization is needed if top_p is 1.\n    return;\n  }\n\n  int vocab_size = prob->shape[prob->ndim - 1];\n  float* __restrict p_prob =\n      static_cast<float*>(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * vocab_size);\n\n  // We manually choose the cutoff values of \"top_p / 256\" and \"top_p / 8192\".\n  // In most of the cases, only one round is needed.\n  std::vector<double> cutoff_values{top_p / 256, top_p / 8192, 0.0f};\n\n  // Create the upper partition vector and the lower partition rolling vectors.\n  std::vector<float> 
upper_partition;\n  std::vector<float> lower_partitions[2];\n  upper_partition.reserve(vocab_size);\n  lower_partitions[0].reserve(vocab_size);\n  lower_partitions[1].reserve(vocab_size);\n  float upper_partition_sum = 0.0;\n  for (int round = 0; round < static_cast<int>(cutoff_values.size()); ++round) {\n    const float* lower_partition_begin;\n    const float* lower_partition_end;\n    if (round == 0) {\n      lower_partition_begin = p_prob;\n      lower_partition_end = p_prob + vocab_size;\n    } else {\n      int idx = (round - 1) & 1;\n      lower_partition_begin = lower_partitions[idx].data();\n      lower_partition_end = lower_partitions[idx].data() + lower_partitions[idx].size();\n    }\n\n    // - Partition the last round lower partition into upper and lower\n    // based on the new cutoff value.\n    std::vector<float>& lower_partition = lower_partitions[round & 1];\n    lower_partition.clear();\n    for (const float* ptr = lower_partition_begin; ptr != lower_partition_end; ++ptr) {\n      if (*ptr >= cutoff_values[round]) {\n        upper_partition.push_back(*ptr);\n        upper_partition_sum += *ptr;\n      } else {\n        lower_partition.push_back(*ptr);\n      }\n    }\n    // - If the upper partition sum is at least top p, exit the loop.\n    if (upper_partition_sum >= top_p - eps) {\n      break;\n    }\n  }\n\n  // - Sort the upper partition in descending order.\n  std::sort(upper_partition.begin(), upper_partition.end(), std::greater<>());\n  // - Find the top p boundary prob value.\n  float boundary_value = -1.0;\n  upper_partition_sum = 0.0;\n  for (float upper_value : upper_partition) {\n    upper_partition_sum += upper_value;\n    if (upper_partition_sum >= top_p - eps) {\n      boundary_value = upper_value;\n      break;\n    }\n  }\n  // - Mask all values smaller than the boundary to 0.\n  float renormalize_sum = 0.0;\n  std::vector<int> upper_partition_indices;\n  upper_partition_indices.reserve(vocab_size);\n  for (int i = 0; i < 
vocab_size; ++i) {\n    if (p_prob[i] >= boundary_value) {\n      upper_partition_indices.push_back(i);\n      renormalize_sum += p_prob[i];\n    } else {\n      p_prob[i] = 0.0;\n    }\n  }\n  // - Renormalize.\n  for (int idx : upper_partition_indices) {\n    p_prob[idx] /= renormalize_sum;\n  }\n}\n\nnamespace detail {\n\n/*! \\brief Implementation of getting top probs on CPU. */\ntemplate <int num_top_probs>\nstd::vector<TokenProbPair> ComputeTopProbsImpl(const float* p_prob, int ndata) {\n  std::vector<TokenProbPair> top_probs;\n  top_probs.reserve(num_top_probs);\n  for (int i = 0; i < num_top_probs; ++i) {\n    top_probs.emplace_back(-1, -1.0f);\n  }\n\n  float sum_prob = 0.0;\n  // Selection argsort.\n  for (int p = 0; p < ndata; ++p) {\n    int i = num_top_probs - 1;\n    for (; i >= 0; --i) {\n      if (p_prob[p] > top_probs[i].second) {\n        if (i != num_top_probs - 1) {\n          top_probs[i + 1] = top_probs[i];\n        }\n      } else {\n        break;\n      }\n    }\n    if (i != num_top_probs - 1) {\n      top_probs[i + 1] = {p, p_prob[p]};\n    }\n\n    // Early exit.\n    sum_prob += p_prob[p];\n    if (1 - sum_prob <= top_probs[num_top_probs - 1].second) {\n      break;\n    }\n  }\n  return top_probs;\n}\n\n}  // namespace detail\n\n/*! \\brief Get the `num_top_probs` tokens with the highest probabilities, and their probabilities. 
*/\ninline std::vector<TokenProbPair> ComputeTopProbs(Tensor prob, int unit_offset, int num_top_probs) {\n  TVM_FFI_ICHECK_LE(num_top_probs, 5);\n  TVM_FFI_ICHECK_EQ(prob->ndim, 2);\n  int ndata = prob->shape[1];\n  const float* __restrict p_prob =\n      static_cast<float*>(__builtin_assume_aligned(prob->data, 4)) + (unit_offset * ndata);\n  switch (num_top_probs) {\n    case 0:\n      return {};\n    case 1:\n      return detail::ComputeTopProbsImpl<1>(p_prob, ndata);\n    case 2:\n      return detail::ComputeTopProbsImpl<2>(p_prob, ndata);\n    case 3:\n      return detail::ComputeTopProbsImpl<3>(p_prob, ndata);\n    case 4:\n      return detail::ComputeTopProbsImpl<4>(p_prob, ndata);\n    case 5:\n      return detail::ComputeTopProbsImpl<5>(p_prob, ndata);\n  }\n  throw;\n}\n\n/********************* CPU Sampler *********************/\n\nclass CPUSampler : public SamplerObj {\n public:\n  explicit CPUSampler(Optional<EventTraceRecorder> trace_recorder)\n      : trace_recorder_(std::move(trace_recorder)) {}\n\n  Tensor BatchRenormalizeProbsByTopP(Tensor probs_on_device,                  //\n                                     const std::vector<int>& sample_indices,  //\n                                     const Array<String>& request_ids,        //\n                                     const Array<GenerationConfig>& generation_cfg) final {\n    // probs_on_device: (n, v)\n    TVM_FFI_ICHECK_EQ(probs_on_device->ndim, 2);\n    // - Copy probs to CPU\n    RECORD_EVENT(trace_recorder_, request_ids, \"start copy probs to CPU\");\n    Tensor probs_on_host = CopyProbsToCPU(probs_on_device);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish copy probs to CPU\");\n    int num_samples = sample_indices.size();\n    int num_probs = probs_on_device->shape[0];\n    int vocab_size = probs_on_device->shape[1];\n    TVM_FFI_ICHECK_EQ(request_ids.size(), num_samples);\n    TVM_FFI_ICHECK_EQ(generation_cfg.size(), num_samples);\n\n    std::vector<int> top_p_indices;\n    
std::vector<double> top_p_values;\n    for (int i = 0; i < num_samples; ++i) {\n      if (top_p_indices.empty() || top_p_indices.back() != sample_indices[i]) {\n        top_p_indices.push_back(sample_indices[i]);\n        top_p_values.push_back(generation_cfg[i]->top_p);\n      } else {\n        TVM_FFI_ICHECK(fabs(top_p_values.back() - generation_cfg[i]->top_p) < eps_)\n            << \"Sampler requires the top_p values for each prob distribution to be the same.\";\n      }\n    }\n    if (top_p_indices.empty()) {\n      // Return if top p does not need to be applied.\n      return probs_on_host;\n    }\n\n    tvm::runtime::parallel_for_with_threading_backend(\n        [this, &probs_on_host, &request_ids, &top_p_indices, &top_p_values](int i) {\n          RECORD_EVENT(this->trace_recorder_, request_ids[i], \"start renormalize by top p\");\n          RenormalizeProbByTopP(probs_on_host, top_p_indices[i], top_p_values[i], eps_);\n          RECORD_EVENT(this->trace_recorder_, request_ids[i], \"finish renormalize by top p\");\n        },\n        0, static_cast<int64_t>(top_p_indices.size()));\n\n    return probs_on_host;\n  }\n\n  std::vector<SampleResult> BatchSampleTokensWithProbBeforeTopP(\n      Tensor probs_on_device,                         //\n      const std::vector<int>& sample_indices,         //\n      const Array<String>& request_ids,               //\n      const Array<GenerationConfig>& generation_cfg,  //\n      const std::vector<RandomGenerator*>& rngs) final {\n    // probs_on_device: (n, v)\n    TVM_FFI_ICHECK_EQ(probs_on_device->ndim, 2);\n    // - Copy probs to CPU\n    RECORD_EVENT(trace_recorder_, request_ids, \"start copy probs to CPU\");\n    Tensor probs_on_host = CopyProbsToCPU(probs_on_device);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish copy probs to CPU\");\n\n    return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs,\n                                 /*top_p_applied=*/false);\n  }\n\n  
std::vector<SampleResult> BatchSampleTokensWithProbAfterTopP(\n      Tensor probs_on_host,                           //\n      const std::vector<int>& sample_indices,         //\n      const Array<String>& request_ids,               //\n      const Array<GenerationConfig>& generation_cfg,  //\n      const std::vector<RandomGenerator*>& rngs) final {\n    return BatchSampleTokensImpl(probs_on_host, sample_indices, request_ids, generation_cfg, rngs,\n                                 /*top_p_applied=*/true);\n  }\n\n  std::pair<std::vector<std::vector<SampleResult>>, std::vector<int>>\n  BatchVerifyDraftTokensWithProbAfterTopP(\n      Tensor probs_on_host, const Array<String>& request_ids,\n      const std::vector<int>& cum_verify_lengths, const Array<GenerationConfig>& generation_cfg,\n      const std::vector<RandomGenerator*>& rngs,\n      const std::vector<std::vector<SampleResult>>& draft_output_tokens,\n      const std::vector<int64_t>& token_tree_parent_ptr, Tensor draft_probs_on_device) final {\n    // probs_on_host: (n, v)\n    RECORD_EVENT(trace_recorder_, request_ids, \"start draft verification\");\n    TVM_FFI_ICHECK_EQ(probs_on_host->ndim, 2);\n\n    int num_sequence = static_cast<int>(cum_verify_lengths.size()) - 1;\n    TVM_FFI_ICHECK_EQ(rngs.size(), num_sequence);\n    TVM_FFI_ICHECK_EQ(draft_output_tokens.size(), num_sequence);\n\n    Tensor draft_probs_on_host = draft_probs_on_device.CopyTo(DLDevice{kDLCPU, 0});\n    std::vector<std::vector<SampleResult>> sample_results;\n    sample_results.resize(num_sequence);\n\n    float* __restrict global_p_probs =\n        static_cast<float*>(__builtin_assume_aligned(probs_on_host->data, 4));\n    int vocab_size = probs_on_host->shape[1];\n\n    std::vector<int> last_accepted_tree_node(num_sequence, 0);\n    tvm::runtime::parallel_for_with_threading_backend(\n        [&](int i) {\n          int verify_start = cum_verify_lengths[i];\n          int verify_end = cum_verify_lengths[i + 1];\n\n          
TVM_FFI_ICHECK_EQ(token_tree_parent_ptr[verify_start], -1);\n          for (int j = verify_start + 1; j < verify_end; ++j) {\n            TVM_FFI_ICHECK_EQ(token_tree_parent_ptr[j], j - verify_start - 1)\n                << \"CPU sampler only supports chain-style draft tokens.\";\n          }\n\n          int cur_token_idx = 0;\n          // Sub 1 to ignore the last prediction.\n          for (; cur_token_idx < verify_end - verify_start - 1; ++cur_token_idx) {\n            float* p_probs = global_p_probs + (verify_start + cur_token_idx) * vocab_size;\n            int cur_token = draft_output_tokens[i][cur_token_idx].GetTokenId();\n            float q_value = draft_output_tokens[i][cur_token_idx].sampled_token_id.second;\n            float p_value = p_probs[cur_token];\n\n            if (p_value >= q_value) {\n              sample_results[i].push_back(\n                  SampleResult{{cur_token, p_value},\n                               ComputeTopProbs(probs_on_host, verify_start + cur_token_idx,\n                                               generation_cfg[i]->top_logprobs)});\n              continue;\n            }\n            float r = rngs[i]->GetRandomNumber();\n            if (r < p_value / (q_value + eps_)) {\n              sample_results[i].push_back(\n                  SampleResult{{cur_token, p_value},\n                               ComputeTopProbs(probs_on_host, verify_start + cur_token_idx,\n                                               generation_cfg[i]->top_logprobs)});\n              continue;\n            }\n\n            // normalize a new probability distribution\n            double sum_v = 0.0;\n            const float* __restrict p_qdist =\n                static_cast<float*>(__builtin_assume_aligned(draft_probs_on_host->data, 4)) +\n                (verify_start + cur_token_idx + 1) * vocab_size;\n\n            for (int j = 0; j < vocab_size; ++j) {\n              p_probs[j] = std::max(p_probs[j] - p_qdist[j], 0.0f);\n              sum_v += 
p_probs[j];\n            }\n            for (int j = 0; j < vocab_size; ++j) {\n              p_probs[j] /= sum_v;\n            }\n\n            // sample a new token from the new distribution\n            SampleResult sample_result;\n            sample_result.sampled_token_id = SampleTopPFromProb(\n                probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx,\n                /*top_p=*/1.0f, rngs[i]->GetRandomNumber());\n            sample_result.top_prob_tokens = ComputeTopProbs(\n                probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs);\n            sample_results[i].push_back(sample_result);\n            break;\n          }\n          last_accepted_tree_node[i] = cur_token_idx;\n          // if cur_token_idx == verify_end - verify_start - 1\n          // all draft tokens are accepted\n          // we sample a new token\n          if (cur_token_idx == verify_end - verify_start - 1) {\n            SampleResult sample_result;\n            // sample a new token from the original distribution\n            sample_result.sampled_token_id = SampleTopPFromProb(\n                probs_on_host, verify_start + cur_token_idx, verify_start + cur_token_idx,\n                /*top_p=*/1.0f, rngs[i]->GetRandomNumber());\n            sample_result.top_prob_tokens = ComputeTopProbs(\n                probs_on_host, verify_start + cur_token_idx, generation_cfg[i]->top_logprobs);\n            sample_results[i].push_back(sample_result);\n          }\n        },\n        0, num_sequence);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish draft verification\");\n    return {sample_results, last_accepted_tree_node};\n  }\n\n private:\n  std::vector<SampleResult> BatchSampleTokensImpl(Tensor probs_on_host,                           //\n                                                  const std::vector<int>& sample_indices,         //\n                                                  const Array<String>& 
request_ids,               //\n                                                  const Array<GenerationConfig>& generation_cfg,  //\n                                                  const std::vector<RandomGenerator*>& rngs,      //\n                                                  bool top_p_applied) {\n    // probs_on_host: (n, v)\n    RECORD_EVENT(trace_recorder_, request_ids, \"start sampling\");\n    TVM_FFI_ICHECK_EQ(probs_on_host->ndim, 2);\n    TVM_FFI_ICHECK_EQ(probs_on_host->device.device_type, DLDeviceType::kDLCPU);\n\n    // - Sample tokens from probabilities.\n    int n = request_ids.size();\n    TVM_FFI_ICHECK_EQ(generation_cfg.size(), n);\n    TVM_FFI_ICHECK_EQ(rngs.size(), n);\n\n    std::vector<SampleResult> sample_results;\n    sample_results.resize(n);\n\n    tvm::runtime::parallel_for_with_threading_backend(\n        [this, &sample_results, &probs_on_host, &generation_cfg, &rngs, &request_ids, top_p_applied,\n         sample_indices](int i) {\n          RECORD_EVENT(this->trace_recorder_, request_ids[i], \"start sample token\");\n          // Sample top p from probability.\n          double top_p =\n              top_p_applied\n                  ? 1.0f\n                  : (generation_cfg[i]->temperature < eps_ ? 0.0 : generation_cfg[i]->top_p);\n          sample_results[i].sampled_token_id = SampleTopPFromProb(\n              probs_on_host, i, sample_indices[i], top_p, rngs[i]->GetRandomNumber());\n          sample_results[i].top_prob_tokens =\n              ComputeTopProbs(probs_on_host, i, generation_cfg[i]->top_logprobs);\n          RECORD_EVENT(this->trace_recorder_, request_ids[i], \"finish sample token\");\n        },\n        0, n);\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish sampling\");\n    return sample_results;\n  }\n\n  /*! \\brief Copy prob distributions from device to CPU. 
*/\n  Tensor CopyProbsToCPU(Tensor probs_on_device) {\n    // probs_on_device: (n, v)\n    if (probs_on_device->device.device_type == kDLCPU) {\n      return probs_on_device;\n    }\n\n    TVM_FFI_ICHECK(probs_on_device->device.device_type != kDLCPU);\n    if (probs_host_.defined()) {\n      TVM_FFI_ICHECK_EQ(probs_host_->shape[1], probs_on_device->shape[1]);\n    }\n\n    int64_t init_size = probs_host_.defined() ? probs_host_->shape[0] : 32;\n    int64_t num_tokens = probs_on_device->shape[0];\n    int64_t vocab_size = probs_on_device->shape[1];\n    while (init_size < num_tokens) {\n      init_size *= 2;\n    }\n    if (!probs_host_.defined() || init_size != probs_host_->shape[0]) {\n      probs_host_ =\n          Tensor::Empty({init_size, vocab_size}, probs_on_device->dtype, DLDevice{kDLCPU, 0});\n    }\n    TVM_FFI_ICHECK_LE(num_tokens, probs_host_->shape[0]);\n    Tensor view = probs_host_.CreateView({num_tokens, vocab_size}, probs_on_device->dtype);\n    view.CopyFrom(probs_on_device);\n    return view;\n  }\n\n  /*! \\brief The event trace recorder for requests. */\n  Optional<EventTraceRecorder> trace_recorder_;\n  /*! \\brief Customized function which computes prob distribution from logits */\n  Function flogits_to_probs_inplace_;\n  /*! \\brief Probability distribution array on CPU. */\n  Tensor probs_host_{nullptr};\n  const float eps_ = 1e-5;\n};\n\nSampler Sampler::CreateCPUSampler(Optional<EventTraceRecorder> trace_recorder) {\n  return Sampler(tvm::ffi::make_object<CPUSampler>(std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/sampler/gpu_sampler.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/sampler/gpu_sampler.cc\n * \\brief The implementation for GPU sampler functions.\n */\n#include <tvm/ffi/function.h>\n#include <tvm/runtime/device_api.h>\n#include <tvm/runtime/nvtx.h>\n#include <tvm/runtime/tensor.h>\n\n#include \"../../support/random.h\"\n#include \"sampler.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\ninline bool FlashInferSamplingAvailable(Device device) {\n  // Device must be CUDA, and FlashInfer must be enabled.\n  if (device.device_type != DLDeviceType::kDLCUDA ||\n      !Function::GetGlobal(\"flashinfer.sampling.parallel_sampling_from_prob\").has_value()) {\n    return false;\n  }\n  // Compute version must be at least 8.0\n  Any rv;\n  DeviceAPI::Get(device)->GetAttr(device, kComputeVersion, &rv);\n  std::string compute_version = rv.cast<std::string>();\n  std::string major_version = compute_version.substr(0, compute_version.find('.'));\n  return std::stoi(major_version) >= 8;\n}\n\ninline void CopyArray(Tensor src, Tensor dst, TVMStreamHandle copy_stream) {\n  DLTensor dl_dst = *(dst.operator->());\n  Tensor::CopyFromTo(src.operator->(), &dl_dst, copy_stream);\n}\n\ninline void SyncCopyStream(Device device, TVMStreamHandle compute_stream,\n                           TVMStreamHandle copy_stream) {\n  // - If there is no particular copy stream, no action is needed.\n  if (copy_stream == nullptr) {\n    return;\n  }\n  // - Sync two streams.\n  DeviceAPI::Get(device)->SyncStreamFromTo(device, copy_stream, compute_stream);\n}\n\n/*********************** GPU Sampler ***********************/\n\nclass GPUSampler : public SamplerObj {\n public:\n  explicit GPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft, DLDevice device,\n                      Optional<EventTraceRecorder> trace_recorder)\n      : max_num_sample_(max_num_sample),\n        vocab_size_(vocab_size),\n        
flashinfer_sampling_available_(FlashInferSamplingAvailable(device)),\n        device_(device),\n        gpu_multinomial_from_uniform_func_(ft->gpu_multinomial_from_uniform_func_),\n        gpu_argsort_probs_func_(ft->gpu_argsort_probs_func_),\n        gpu_sample_with_top_p_func_(ft->gpu_sample_with_top_p_func_),\n        gpu_sampler_take_probs_func_(ft->gpu_sampler_take_probs_func_),\n        gpu_verify_draft_tokens_func_(ft->gpu_verify_draft_tokens_func_),\n        gpu_renormalize_by_top_p_func_(ft->gpu_renormalize_by_top_p_func_),\n        trace_recorder_(std::move(trace_recorder)) {\n    TVM_FFI_ICHECK(gpu_multinomial_from_uniform_func_.defined());\n    TVM_FFI_ICHECK(gpu_argsort_probs_func_.defined());\n    TVM_FFI_ICHECK(gpu_sample_with_top_p_func_.defined());\n    TVM_FFI_ICHECK(gpu_sampler_take_probs_func_.defined());\n\n    flashinfer_multinomial_sample_func_ =\n        Function::GetGlobal(\"flashinfer.sampling.parallel_sampling_from_prob\");\n\n    Device preferred_host_device = GetPreferredHostDevice(device);\n    // We support at most 5 top prob results for each sequence.\n    // Initialize auxiliary arrays on CPU.\n    uniform_samples_host_ = Tensor::Empty({max_num_sample}, dtype_f32_, preferred_host_device);\n    sample_indices_host_ = Tensor::Empty({max_num_sample}, dtype_i32_, preferred_host_device);\n    top_p_host_ = Tensor::Empty({max_num_sample}, dtype_f32_, preferred_host_device);\n    top_p_init_pivots_host_ = Tensor::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_,\n                                            preferred_host_device);\n    top_prob_offsets_host_ = Tensor::Empty({max_num_sample * 5}, dtype_i32_, preferred_host_device);\n    draft_tokens_host_ = Tensor::Empty({max_num_sample}, dtype_i32_, preferred_host_device);\n    token_tree_first_child_host_ =\n        Tensor::Empty({max_num_sample}, dtype_i32_, preferred_host_device);\n    token_tree_next_sibling_host_ =\n        Tensor::Empty({max_num_sample}, dtype_i32_, 
preferred_host_device);\n    token_tree_parent_ptr_host_ =\n        Tensor::Empty({max_num_sample}, dtype_i32_, preferred_host_device);\n    sampled_token_ids_host_ = Tensor::Empty({max_num_sample}, dtype_i32_, preferred_host_device);\n    sampled_probs_host_ = Tensor::Empty({max_num_sample}, dtype_f32_, preferred_host_device);\n    top_prob_probs_host_ = Tensor::Empty({max_num_sample * 5}, dtype_f32_, preferred_host_device);\n    top_prob_indices_host_ = Tensor::Empty({max_num_sample * 5}, dtype_i32_, preferred_host_device);\n    // Initialize auxiliary arrays on GPU.\n    uniform_samples_device_ = Tensor::Empty({max_num_sample}, dtype_f32_, device);\n    sample_indices_device_ = Tensor::Empty({max_num_sample}, dtype_i32_, device);\n    top_p_device_ = Tensor::Empty({max_num_sample}, dtype_f32_, device);\n    top_p_init_pivots_device_ =\n        Tensor::Empty({max_num_sample, num_top_p_cutoff_pivots_}, dtype_f32_, device);\n    top_prob_offsets_device_ = Tensor::Empty({max_num_sample * 5}, dtype_i32_, device);\n    draft_tokens_device_ = Tensor::Empty({max_num_sample}, dtype_i32_, device);\n    token_tree_first_child_device_ = Tensor::Empty({max_num_sample}, dtype_i32_, device);\n    token_tree_next_sibling_device_ = Tensor::Empty({max_num_sample}, dtype_i32_, device);\n    token_tree_parent_ptr_device_ = Tensor::Empty({max_num_sample}, dtype_i32_, device);\n    sampled_token_ids_device_ = Tensor::Empty({max_num_sample}, dtype_i32_, device);\n\n    // If the device is CUDA/ROCm, we create a standalone copy stream in\n    // order to hide the latency of auxiliary array copies.\n    if (device.device_type == DLDeviceType::kDLCUDA ||\n        device.device_type == DLDeviceType::kDLROCM) {\n      // The compute stream is the default stream.\n      compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);\n      copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);\n    }\n  }\n\n  ~GPUSampler() {\n    // Free the copy stream if defined.\n    if 
(copy_stream_ != nullptr) {\n      DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);\n    }\n  }\n\n  Tensor BatchRenormalizeProbsByTopP(Tensor probs_on_device,                  //\n                                     const std::vector<int>& sample_indices,  //\n                                     const Array<String>& request_ids,        //\n                                     const Array<GenerationConfig>& generation_cfg) final {\n    NVTXScopedRange nvtx_scope(\"BatchRenormalizeProbsByTopP\");\n    // probs_on_device: (n, v)\n    RECORD_EVENT(trace_recorder_, request_ids, \"start renormalization by top p\");\n    TVM_FFI_ICHECK_EQ(probs_on_device->ndim, 2);\n    int num_samples = sample_indices.size();\n    int num_probs = probs_on_device->shape[0];\n    int vocab_size = probs_on_device->shape[1];\n    TVM_FFI_ICHECK_LE(num_probs, max_num_sample_);\n    TVM_FFI_ICHECK_EQ(generation_cfg.size(), num_samples);\n\n    // - Check if there is need for applying top p.\n    bool need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size);\n    if (!need_top_p) {\n      return probs_on_device;\n    }\n\n    // - Copy auxiliary array for top-p and initial pivots.\n    Tensor top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_);\n    Tensor top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_);\n    CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_);\n\n    Tensor top_p_init_pivots_host =\n        top_p_init_pivots_host_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);\n    Tensor top_p_init_pivots_device =\n        top_p_init_pivots_device_.CreateView({num_probs, num_top_p_cutoff_pivots_}, dtype_f32_);\n    const float* p_top_p = static_cast<const float*>(top_p_host->data);\n    float* p_top_p_init_pivots = static_cast<float*>(top_p_init_pivots_host->data);\n    for (int i = 0; i < num_probs; ++i) {\n      if (1 - p_top_p[i] >= 0.02) {\n        p_top_p_init_pivots[i * 
num_top_p_cutoff_pivots_] =\n            std::min(1 - p_top_p[i], static_cast<float>(0.5));\n        p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = 0.02;\n        p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = 0.01;\n      } else {\n        p_top_p_init_pivots[i * num_top_p_cutoff_pivots_] = 1 - p_top_p[i];\n        p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 1] = (1 - p_top_p[i]) / 2;\n        p_top_p_init_pivots[i * num_top_p_cutoff_pivots_ + 2] = (1 - p_top_p[i]) / 4;\n      }\n    }\n    CopyArray(/*src=*/top_p_init_pivots_host, /*dst=*/top_p_init_pivots_device, copy_stream_);\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    // - Renormalize the prob with top p.\n    Tensor renormed_probs_on_device =\n        gpu_renormalize_by_top_p_func_(probs_on_device, top_p_device, top_p_init_pivots_device)\n            .cast<Tensor>();\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish renormalization by top p\");\n    return renormed_probs_on_device;\n  }\n\n  std::vector<SampleResult> BatchSampleTokensWithProbBeforeTopP(\n      Tensor probs_on_device,                         //\n      const std::vector<int>& sample_indices,         //\n      const Array<String>& request_ids,               //\n      const Array<GenerationConfig>& generation_cfg,  //\n      const std::vector<RandomGenerator*>& rngs) final {\n    NVTXScopedRange nvtx_scope(\"BatchSampleTokensWithProbBeforeTopP\");\n    return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids,\n                                 generation_cfg, rngs, /*top_p_applied=*/false);\n  }\n\n  std::vector<SampleResult> BatchSampleTokensWithProbAfterTopP(\n      Tensor probs_on_device,                         //\n      const std::vector<int>& sample_indices,         //\n      const Array<String>& request_ids,               //\n      const Array<GenerationConfig>& generation_cfg,  //\n      const std::vector<RandomGenerator*>& rngs) final {\n    
NVTXScopedRange nvtx_scope(\"BatchSampleTokensWithProbAfterTopP\");\n    return BatchSampleTokensImpl(std::move(probs_on_device), sample_indices, request_ids,\n                                 generation_cfg, rngs, /*top_p_applied=*/true);\n  }\n\n  std::pair<std::vector<std::vector<SampleResult>>, std::vector<int>>\n  BatchVerifyDraftTokensWithProbAfterTopP(\n      Tensor probs_on_device, const Array<String>& request_ids,\n      const std::vector<int>& cum_verify_lengths, const Array<GenerationConfig>& generation_cfg,\n      const std::vector<RandomGenerator*>& rngs,\n      const std::vector<std::vector<SampleResult>>& draft_output_tokens,\n      const std::vector<int64_t>& token_tree_parent_ptr, Tensor draft_probs_on_device) final {\n    NVTXScopedRange nvtx_scope(\"BatchVerifyDraftTokensWithProbAfterTopP\");\n    std::vector<std::vector<SampleResult>> sample_results;\n    // probs_on_device: (n, v)\n    RECORD_EVENT(trace_recorder_, request_ids, \"start draft verification\");\n    TVM_FFI_ICHECK_EQ(probs_on_device->ndim, 2);\n\n    int num_sequence = static_cast<int>(cum_verify_lengths.size()) - 1;\n    TVM_FFI_ICHECK_EQ(rngs.size(), num_sequence);\n    TVM_FFI_ICHECK_EQ(draft_output_tokens.size(), num_sequence);\n    sample_results.resize(num_sequence);\n\n    int num_nodes = cum_verify_lengths.back();\n    TVM_FFI_ICHECK(num_nodes <= max_num_sample_);\n    TVM_FFI_ICHECK_EQ(draft_probs_on_device->shape[0], num_nodes);\n    Tensor uniform_samples_device = GenerateUniformSamples(rngs, cum_verify_lengths);\n    Tensor draft_tokens_host = draft_tokens_host_.CreateView({num_nodes}, dtype_i32_);\n    Tensor draft_tokens_device = draft_tokens_device_.CreateView({num_nodes}, dtype_i32_);\n\n    // Copy draft tokens to GPU\n    int* p_draft_tokens_host = static_cast<int*>(draft_tokens_host->data);\n    for (int i = 0; i < num_sequence; i++) {\n      const std::vector<SampleResult>& draft_output_tokens_i = draft_output_tokens[i];\n      int start = 
cum_verify_lengths[i];\n      int end = cum_verify_lengths[i + 1];\n      // start/end is the range of the sequence i in probs_on_device, which includes the prob dist\n      // of the draft tokens and the last committed token\n      TVM_FFI_ICHECK_EQ(draft_output_tokens_i.size() + 1, end - start);\n      for (int j = 0; j < end - start - 1; j++) {\n        // Copy sampled token id\n        p_draft_tokens_host[start + j + 1] = draft_output_tokens_i[j].GetTokenId();\n      }\n    }\n    CopyArray(draft_tokens_host, draft_tokens_device, copy_stream_);\n\n    Tensor token_tree_first_child_host =\n        token_tree_first_child_host_.CreateView({num_nodes}, dtype_i32_);\n    Tensor token_tree_first_child_device =\n        token_tree_first_child_device_.CreateView({num_nodes}, dtype_i32_);\n    Tensor token_tree_next_sibling_host =\n        token_tree_next_sibling_host_.CreateView({num_nodes}, dtype_i32_);\n    Tensor token_tree_next_sibling_device =\n        token_tree_next_sibling_device_.CreateView({num_nodes}, dtype_i32_);\n    Tensor token_tree_parent_ptr_host =\n        token_tree_parent_ptr_host_.CreateView({num_sequence}, dtype_i32_);\n    Tensor token_tree_parent_ptr_device =\n        token_tree_parent_ptr_device_.CreateView({num_sequence}, dtype_i32_);\n    std::vector<int> token_tree_child_to_parent(/*n=*/num_nodes);\n\n    int* token_tree_first_child_ptr_host = static_cast<int*>(token_tree_first_child_host->data);\n    int* token_tree_next_sibling_ptr_host = static_cast<int*>(token_tree_next_sibling_host->data);\n    // Build the tree structure on CPU\n    for (int i = 0; i < num_sequence; i++) {\n      // Assuming no tree structure for now\n      int start = cum_verify_lengths[i];\n      int end = cum_verify_lengths[i + 1];\n      TVM_FFI_ICHECK_GE(end - start, 2);\n      for (int j = 0; j < end - start; j++) {\n        int cur_node = j + start;\n        int parent_node =\n            token_tree_parent_ptr[cur_node] != -1 ? 
token_tree_parent_ptr[cur_node] + start : -1;\n        token_tree_first_child_ptr_host[cur_node] = -1;\n        if (parent_node != -1 && token_tree_first_child_ptr_host[parent_node] == -1) {\n          token_tree_first_child_ptr_host[parent_node] = cur_node;\n        }\n        token_tree_child_to_parent[cur_node] = parent_node;\n        if (cur_node + 1 < end && token_tree_parent_ptr[cur_node - start + 1] ==\n                                      token_tree_parent_ptr[cur_node - start]) {\n          token_tree_next_sibling_ptr_host[cur_node] = cur_node + 1;\n        } else {\n          token_tree_next_sibling_ptr_host[cur_node] = -1;\n        }\n      }\n      static_cast<int*>(token_tree_parent_ptr_host->data)[i] = start;  // point to the root\n    }\n    // Copy token tree structure to GPU\n    CopyArray(token_tree_first_child_host, token_tree_first_child_device, copy_stream_);\n    CopyArray(token_tree_next_sibling_host, token_tree_next_sibling_device, copy_stream_);\n    CopyArray(token_tree_parent_ptr_host, token_tree_parent_ptr_device, copy_stream_);\n\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    gpu_verify_draft_tokens_func_(draft_probs_on_device, draft_tokens_device, probs_on_device,\n                                  token_tree_first_child_device, token_tree_next_sibling_device,\n                                  uniform_samples_device, token_tree_parent_ptr_device);\n\n    DeviceAPI::Get(device_)->SyncStreamFromTo(device_, compute_stream_, copy_stream_);\n    CopyArray(token_tree_parent_ptr_device, token_tree_parent_ptr_host, copy_stream_);\n\n    std::vector<SampleResult> additional_sample_result;\n    {\n      additional_sample_result.reserve(num_sequence);\n      // Sample one additional token for each sequence using the probability at the last accepted\n      // token.\n      uniform_samples_device = GenerateUniformSamples(rngs, num_sequence);\n      const Tensor& sample_indices_device = token_tree_parent_ptr_device;\n      // 
Check need_prob_values\n      bool need_prob_values = false;\n      for (int i = 0; i < num_sequence; i++) {\n        need_prob_values |= generation_cfg[i]->logprobs;\n      }\n      std::vector<int> top_prob_offset_indptr;\n      if (!need_prob_values) {\n        top_prob_offset_indptr.resize(num_sequence + 1, 0);\n      } else {\n        // Slow path: if any of the generation config requires prob values, we need to copy\n        // sample_indices to host to compute top_prob_offset_indptr.\n        DeviceAPI::Get(device_)->StreamSync(device_, copy_stream_);\n        std::vector<int> sample_indices;\n        sample_indices.reserve(num_sequence);\n        const int* p_token_tree_parent_ptr = static_cast<int*>(token_tree_parent_ptr_host->data);\n        for (int i = 0; i < num_sequence; i++) {\n          sample_indices.push_back(p_token_tree_parent_ptr[i]);\n        }\n        CheckProbValues(generation_cfg, sample_indices, num_nodes, num_sequence, vocab_size_,\n                        &top_prob_offset_indptr);\n      }\n      auto device_arrays =\n          SampleOnGPU(probs_on_device, uniform_samples_device, sample_indices_device,\n                      /*need_top_p=*/false, need_prob_values, num_nodes, top_prob_offset_indptr);\n      auto host_arrays = CopyArraysToCPU(device_arrays, num_sequence, need_prob_values,\n                                         top_prob_offset_indptr.back());\n      additional_sample_result =\n          CollectSampleResult(host_arrays, num_sequence, need_prob_values, top_prob_offset_indptr);\n    }\n\n    std::vector<int> last_accepted_tree_node;\n    last_accepted_tree_node.reserve(num_sequence);\n    for (int i = 0; i < num_sequence; i++) {\n      int start = cum_verify_lengths[i];\n      int end = cum_verify_lengths[i + 1];\n      int last_accepted = static_cast<int*>(token_tree_parent_ptr_host->data)[i];\n      last_accepted_tree_node.push_back(last_accepted - start);\n      int num_accepted = 0;\n      for (int cur_node = 
last_accepted; cur_node != start;\n           cur_node = token_tree_child_to_parent[cur_node]) {\n        sample_results[i].push_back(draft_output_tokens[i][cur_node - start - 1]);\n        num_accepted++;\n      }\n      std::reverse(sample_results[i].rbegin(), sample_results[i].rbegin() + num_accepted);\n    }\n\n    // Append the additional sample result to the sample_results\n    TVM_FFI_ICHECK_EQ(additional_sample_result.size(), num_sequence);\n    for (int i = 0; i < num_sequence; i++) {\n      sample_results[i].push_back(additional_sample_result[i]);\n    }\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish draft verification\");\n    return {sample_results, last_accepted_tree_node};\n  }\n\n private:\n  std::vector<SampleResult> BatchSampleTokensImpl(Tensor probs_on_device,                         //\n                                                  const std::vector<int>& sample_indices,         //\n                                                  const Array<String>& request_ids,               //\n                                                  const Array<GenerationConfig>& generation_cfg,  //\n                                                  const std::vector<RandomGenerator*>& rngs,      //\n                                                  bool top_p_applied) {\n    // probs_on_device: (n, v)\n    RECORD_EVENT(trace_recorder_, request_ids, \"start sampling\");\n    TVM_FFI_ICHECK_EQ(probs_on_device->ndim, 2);\n    TVM_FFI_ICHECK_EQ(probs_on_device->device.device_id, device_.device_id);\n    TVM_FFI_ICHECK_EQ(probs_on_device->device.device_type, device_.device_type);\n    int num_samples = sample_indices.size();\n    int num_probs = probs_on_device->shape[0];\n    int vocab_size = probs_on_device->shape[1];\n    if (num_samples == 0) {\n      // This synchronization is necessary for making sure that this round\n      // of model forward is finished.\n      DeviceAPI::Get(device_)->StreamSync(device_, compute_stream_);\n      return {};\n   
 }\n    TVM_FFI_ICHECK_EQ(request_ids.size(), num_samples);\n    TVM_FFI_ICHECK_EQ(generation_cfg.size(), num_samples);\n    TVM_FFI_ICHECK_EQ(rngs.size(), num_samples);\n\n    // Since `num_samples` may be larger than `max_num_sample_` in some cases,\n    // we apply chunking to support large `num_samples`.\n    std::vector<SampleResult> sample_results;\n    if (num_samples <= max_num_sample_) {\n      sample_results = ChunkSampleTokensImpl(probs_on_device, sample_indices, generation_cfg, rngs,\n                                             top_p_applied);\n    } else {\n      for (int chunk_start = 0; chunk_start < num_samples; chunk_start += max_num_sample_) {\n        int chunk_end = std::min(chunk_start + max_num_sample_, num_samples);\n        std::vector<int> sample_indices_chunk(sample_indices.begin() + chunk_start,\n                                              sample_indices.begin() + chunk_end);\n        Array<GenerationConfig> generation_cfg_chunk(generation_cfg.begin() + chunk_start,\n                                                     generation_cfg.begin() + chunk_end);\n        std::vector<RandomGenerator*> rngs_chunk(rngs.begin() + chunk_start,\n                                                 rngs.begin() + chunk_end);\n        std::vector<SampleResult> sample_results_chunk = ChunkSampleTokensImpl(\n            probs_on_device, sample_indices_chunk, generation_cfg_chunk, rngs_chunk, top_p_applied);\n        sample_results.insert(sample_results.end(), sample_results_chunk.begin(),\n                              sample_results_chunk.end());\n      }\n    }\n\n    RECORD_EVENT(trace_recorder_, request_ids, \"finish sampling\");\n    return sample_results;\n  }\n\n  /*! \\brief Collect the sampling results from the computed Tensor results. 
*/\n  std::vector<SampleResult> CollectSampleResult(const std::vector<Tensor>& host_arrays,\n                                                int num_samples, bool need_prob_values,\n                                                const std::vector<int> top_prob_offset_indptr) {\n    const int* p_sampled_token_ids = static_cast<const int*>(host_arrays[0]->data);\n    const float* p_sampled_probs = nullptr;\n    const float* p_top_prob_probs = nullptr;\n    const int* p_top_prob_indices = nullptr;\n    if (need_prob_values) {\n      p_sampled_probs = static_cast<const float*>(host_arrays[1]->data);\n      p_top_prob_probs = static_cast<const float*>(host_arrays[2]->data);\n      p_top_prob_indices = static_cast<const int*>(host_arrays[3]->data);\n    }\n    std::vector<SampleResult> sample_results;\n    sample_results.reserve(num_samples);\n    TVM_FFI_ICHECK_EQ(top_prob_offset_indptr.size(), num_samples + 1);\n    for (int i = 0; i < num_samples; ++i) {\n      // Note: we set the probability in SampleResult to 1.0 since prob value is not needed.\n      float sampled_prob = need_prob_values ? 
p_sampled_probs[i] : 1.0;\n      std::vector<TokenProbPair> top_prob_tokens;\n      top_prob_tokens.reserve(top_prob_offset_indptr[i + 1] - top_prob_offset_indptr[i]);\n      for (int j = top_prob_offset_indptr[i]; j < top_prob_offset_indptr[i + 1]; ++j) {\n        top_prob_tokens.emplace_back(p_top_prob_indices[j], p_top_prob_probs[j]);\n      }\n      sample_results.push_back(\n          SampleResult{{p_sampled_token_ids[i], sampled_prob}, top_prob_tokens});\n    }\n    return sample_results;\n  }\n\n  std::vector<SampleResult> ChunkSampleTokensImpl(Tensor probs_on_device,                         //\n                                                  const std::vector<int>& sample_indices,         //\n                                                  const Array<GenerationConfig>& generation_cfg,  //\n                                                  const std::vector<RandomGenerator*>& rngs,      //\n                                                  bool top_p_applied) {\n    // probs_on_device: (n, v)\n    int num_samples = sample_indices.size();\n    int num_probs = probs_on_device->shape[0];\n    int vocab_size = probs_on_device->shape[1];\n\n    // - Generate random numbers.\n    //   Copy the random numbers and sample indices.\n    auto uniform_samples_device = GenerateUniformSamples(rngs, num_samples);\n    auto sample_indices_device = CopySampleIndicesToGPU(sample_indices);\n\n    // - Check if there is need for applying top p or prob values,\n    //   so that argsort is needed.\n    bool need_top_p = false;\n    if (!top_p_applied) {\n      need_top_p = CheckTopP(generation_cfg, sample_indices, num_probs, num_samples, vocab_size);\n    }\n    // The indptr array of the number of top probs for each sample.\n    std::vector<int> top_prob_offset_indptr;\n    bool need_prob_values = CheckProbValues(generation_cfg, sample_indices, num_probs, num_samples,\n                                            vocab_size, &top_prob_offset_indptr);\n\n    // - Sample 
tokens on GPU, and take out the probability values if needed.\n    std::vector<Tensor> device_arrays =\n        SampleOnGPU(probs_on_device, uniform_samples_device, sample_indices_device, need_top_p,\n                    need_prob_values, num_probs, top_prob_offset_indptr);\n\n    // - Copy the GPU sampling function results to CPU.\n    std::vector<Tensor> host_arrays = CopyArraysToCPU(device_arrays, num_samples, need_prob_values,\n                                                      top_prob_offset_indptr.back());\n\n    // - Collect the sampling results.\n    return CollectSampleResult(host_arrays, num_samples, need_prob_values, top_prob_offset_indptr);\n  }\n\n  /*! \\brief Generate num_samples uniform random numbers, and copy them to GPU. */\n  Tensor GenerateUniformSamples(const std::vector<RandomGenerator*>& rngs, int num_samples) {\n    float* p_uniform_samples = static_cast<float*>(uniform_samples_host_->data);\n    for (int i = 0; i < num_samples; ++i) {\n      p_uniform_samples[i] = rngs[i]->GetRandomNumber();\n    }\n    Tensor uniform_samples_host = uniform_samples_host_.CreateView({num_samples}, dtype_f32_);\n    Tensor uniform_samples_device = uniform_samples_device_.CreateView({num_samples}, dtype_f32_);\n    CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device, copy_stream_);\n    return uniform_samples_device;\n  }\n\n  /*! \\brief Generate uniform random numbers and copy them to GPU. The\n   * number of samples for each random generator is given by `cum_num_samples`. 
*/\n  Tensor GenerateUniformSamples(const std::vector<RandomGenerator*>& rngs,\n                                const std::vector<int>& cum_num_samples) {\n    float* p_uniform_samples = static_cast<float*>(uniform_samples_host_->data);\n    int total_samples = cum_num_samples.back();\n    for (int i = 0; i + 1 < static_cast<int>(cum_num_samples.size()); ++i) {\n      for (int j = cum_num_samples[i]; j < cum_num_samples[i + 1]; ++j) {\n        p_uniform_samples[j] = rngs[i]->GetRandomNumber();\n      }\n    }\n    Tensor uniform_samples_host = uniform_samples_host_.CreateView({total_samples}, dtype_f32_);\n    Tensor uniform_samples_device = uniform_samples_device_.CreateView({total_samples}, dtype_f32_);\n    CopyArray(/*src=*/uniform_samples_host, /*dst=*/uniform_samples_device, copy_stream_);\n    return uniform_samples_device;\n  }\n\n  /*! \\brief Copy the sample indices to GPU. */\n  Tensor CopySampleIndicesToGPU(const std::vector<int>& sample_indices) {\n    int* p_sample_indices = static_cast<int*>(sample_indices_host_->data);\n    std::copy(sample_indices.begin(), sample_indices.end(), p_sample_indices);\n    // Copy the sample indices to GPU.\n    int num_samples = static_cast<int>(sample_indices.size());\n    Tensor sample_indices_host = sample_indices_host_.CreateView({num_samples}, dtype_i32_);\n    Tensor sample_indices_device = sample_indices_device_.CreateView({num_samples}, dtype_i32_);\n    CopyArray(/*src=*/sample_indices_host, /*dst=*/sample_indices_device, copy_stream_);\n    return sample_indices_device;\n  }\n\n  /*! \\brief Check if top p is needed. Update host top p array in place. 
*/\n  bool CheckTopP(const Array<GenerationConfig>& generation_cfg,\n                 const std::vector<int>& sample_indices, int num_probs, int num_samples,\n                 int vocab_size) {\n    // Initialize top p values with -1.\n    float* p_top_p = static_cast<float*>(top_p_host_->data);\n    for (int i = 0; i < num_probs; ++i) {\n      p_top_p[i] = -1.0;\n    }\n    bool need_top_p = false;\n    for (int i = 0; i < num_samples; ++i) {\n      if (p_top_p[sample_indices[i]] == -1.0) {\n        p_top_p[sample_indices[i]] = generation_cfg[i]->top_p;\n        need_top_p |= generation_cfg[i]->top_p != 1.0;\n      } else {\n        TVM_FFI_ICHECK(fabs(p_top_p[sample_indices[i]] - generation_cfg[i]->top_p) < eps_)\n            << \"GPU sampler requires the top_p values for each prob distribution are the same.\";\n      }\n    }\n    for (int i = 0; i < num_probs; ++i) {\n      p_top_p[i] = std::max(p_top_p[i], eps_);\n    }\n    return need_top_p;\n  }\n\n  /*! \\brief Check whether prob values are needed, and collect info when necessary. 
*/\n  bool CheckProbValues(const Array<GenerationConfig>& generation_cfg,\n                       const std::vector<int>& sample_indices, int num_probs, int num_samples,\n                       int vocab_size, std::vector<int>* top_prob_offset_indptr) {\n    top_prob_offset_indptr->reserve(num_samples + 1);\n    top_prob_offset_indptr->push_back(0);\n    int* p_top_prob_offsets = static_cast<int*>(top_prob_offsets_host_->data);\n    int num_top_probs = 0;\n    bool need_prob_values = false;\n    for (int i = 0; i < num_samples; ++i) {\n      need_prob_values |= generation_cfg[i]->logprobs;\n      for (int j = 0; j < generation_cfg[i]->top_logprobs; ++j) {\n        p_top_prob_offsets[num_top_probs++] = sample_indices[i] * vocab_size + j;\n      }\n      top_prob_offset_indptr->push_back(top_prob_offset_indptr->back() +\n                                        generation_cfg[i]->top_logprobs);\n    }\n    TVM_FFI_ICHECK_EQ(num_top_probs, top_prob_offset_indptr->back());\n    return need_prob_values;\n  }\n\n  /*! \\brief Sample tokens on GPU. Take out the probability values when needed. 
*/\n  std::vector<Tensor> SampleOnGPU(Tensor probs_on_device, Tensor uniform_samples_device,\n                                  Tensor sample_indices_device,  //\n                                  bool need_top_p, bool need_prob_values, int num_probs,\n                                  const std::vector<int>& top_prob_offset_indptr) {\n    Tensor sampled_token_ids_device{nullptr};\n    Tensor sampled_probs_device{nullptr};\n    Tensor top_prob_probs_device{nullptr};\n    Tensor top_prob_indices_device{nullptr};\n\n    if (!need_top_p && !need_prob_values) {\n      // - Short path: If top_p and prob values are not needed, we directly sample from multinomial.\n      SyncCopyStream(device_, compute_stream_, copy_stream_);\n      if (flashinfer_sampling_available_) {\n        sampled_token_ids_device =\n            sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_);\n        flashinfer_multinomial_sample_func_.value()(probs_on_device, uniform_samples_device,\n                                                    sample_indices_device,\n                                                    sampled_token_ids_device);\n      } else {\n        sampled_token_ids_device =\n            gpu_multinomial_from_uniform_func_(probs_on_device, uniform_samples_device,\n                                               sample_indices_device)\n                .cast<Tensor>();\n      }\n      return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device,\n              top_prob_indices_device};\n    }\n\n    // - Argsort the probability.\n    Array<Tensor> argsort_results = gpu_argsort_probs_func_(probs_on_device).cast<Array<Tensor>>();\n    TVM_FFI_ICHECK_EQ(argsort_results.size(), 2);\n    Tensor sorted_probs_on_device = argsort_results[0];\n    Tensor sorted_indices_on_device = argsort_results[1];\n\n    // - Copy the auxiliary arrays for top-p and prob values in advance.\n    Tensor top_p_device;\n    Tensor top_prob_offsets_device;\n    if 
(need_top_p) {\n      Tensor top_p_host = top_p_host_.CreateView({num_probs}, dtype_f32_);\n      top_p_device = top_p_device_.CreateView({num_probs}, dtype_f32_);\n      CopyArray(/*src=*/top_p_host, /*dst=*/top_p_device, copy_stream_);\n    }\n    if (need_prob_values) {\n      int num_top_probs = top_prob_offset_indptr.back();\n      Tensor top_prob_offsets_host = top_prob_offsets_host_.CreateView({num_top_probs}, dtype_i32_);\n      top_prob_offsets_device = top_prob_offsets_device_.CreateView({num_top_probs}, dtype_i32_);\n      CopyArray(/*src=*/top_prob_offsets_host, /*dst=*/top_prob_offsets_device, copy_stream_);\n    }\n    SyncCopyStream(device_, compute_stream_, copy_stream_);\n\n    if (need_top_p) {\n      // - Sample with top_p applied.\n      sampled_token_ids_device =\n          gpu_sample_with_top_p_func_(sorted_probs_on_device, sorted_indices_on_device,\n                                      uniform_samples_device, sample_indices_device, top_p_device)\n              .cast<Tensor>();\n    } else {\n      // - Sample without top_p.\n      if (flashinfer_sampling_available_) {\n        sampled_token_ids_device =\n            sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_);\n        flashinfer_multinomial_sample_func_.value()(probs_on_device, uniform_samples_device,\n                                                    sample_indices_device,\n                                                    sampled_token_ids_device);\n      } else {\n        sampled_token_ids_device =\n            gpu_multinomial_from_uniform_func_(probs_on_device, uniform_samples_device,\n                                               sample_indices_device)\n                .cast<Tensor>();\n      }\n    }\n\n    if (need_prob_values) {\n      // - Take the probability values.\n      Array<Tensor> prob_value_results =\n          gpu_sampler_take_probs_func_(probs_on_device, sorted_indices_on_device,\n                                       sample_indices_device, 
sampled_token_ids_device,\n                                       top_prob_offsets_device)\n              .cast<Array<Tensor>>();\n      sampled_probs_device = prob_value_results[0];\n      top_prob_probs_device = prob_value_results[1];\n      top_prob_indices_device = prob_value_results[2];\n    }\n\n    return {sampled_token_ids_device, sampled_probs_device, top_prob_probs_device,\n            top_prob_indices_device};\n  }\n\n  /*! \\brief Copy the results of GPU sampling functions back to CPU. */\n  std::vector<Tensor> CopyArraysToCPU(const std::vector<Tensor>& device_arrays,  //\n                                      int num_samples, bool need_prob_values, int num_top_probs) {\n    Tensor sampled_token_ids_device = device_arrays[0];\n    Tensor sampled_probs_device = device_arrays[1];\n    Tensor top_prob_probs_device = device_arrays[2];\n    Tensor top_prob_indices_device = device_arrays[3];\n    TVM_FFI_ICHECK(sampled_token_ids_device.defined());\n    TVM_FFI_ICHECK_EQ(sampled_token_ids_device->ndim, 1);\n    TVM_FFI_ICHECK_EQ(sampled_token_ids_device->shape[0], num_samples);\n    Tensor sampled_token_ids_host = sampled_token_ids_host_.CreateView({num_samples}, dtype_i32_);\n    CopyArray(/*src=*/sampled_token_ids_device, /*dst=*/sampled_token_ids_host, compute_stream_);\n\n    Tensor sampled_probs_host{nullptr};\n    Tensor top_prob_probs_host{nullptr};\n    Tensor top_prob_indices_host{nullptr};\n    if (need_prob_values) {\n      TVM_FFI_ICHECK(sampled_probs_device.defined());\n      TVM_FFI_ICHECK(top_prob_probs_device.defined());\n      TVM_FFI_ICHECK(top_prob_indices_device.defined());\n      TVM_FFI_ICHECK_EQ(sampled_probs_device->ndim, 1);\n      TVM_FFI_ICHECK_EQ(top_prob_probs_device->ndim, 1);\n      TVM_FFI_ICHECK_EQ(top_prob_indices_device->ndim, 1);\n      TVM_FFI_ICHECK_EQ(sampled_probs_device->shape[0], num_samples);\n      TVM_FFI_ICHECK_EQ(top_prob_probs_device->shape[0], num_top_probs);\n      
TVM_FFI_ICHECK_EQ(top_prob_indices_device->shape[0], num_top_probs);\n      sampled_probs_host = sampled_probs_host_.CreateView({num_samples}, dtype_f32_);\n      top_prob_probs_host = top_prob_probs_host_.CreateView({num_top_probs}, dtype_f32_);\n      top_prob_indices_host = top_prob_indices_host_.CreateView({num_top_probs}, dtype_i32_);\n      CopyArray(/*src=*/sampled_probs_device, /*dst=*/sampled_probs_host, compute_stream_);\n      if (num_top_probs > 0) {\n        CopyArray(/*src=*/top_prob_probs_device, /*dst=*/top_prob_probs_host, compute_stream_);\n        CopyArray(/*src=*/top_prob_indices_device, /*dst=*/top_prob_indices_host, compute_stream_);\n      }\n    }\n\n    // Synchronize for CPU to get the correct array results.\n    DeviceAPI::Get(device_)->StreamSync(device_, compute_stream_);\n\n    return {sampled_token_ids_host, sampled_probs_host, top_prob_probs_host, top_prob_indices_host};\n  }\n\n  // Model configurations\n  const int max_num_sample_;\n  const int vocab_size_;\n  const DLDataType dtype_i32_ = DataType::Int(32);\n  const DLDataType dtype_f32_ = DataType::Float(32);\n  const bool flashinfer_sampling_available_;\n  // Functions for sampling on GPU.\n  Device device_;\n  Function gpu_multinomial_from_uniform_func_;\n  Function gpu_argsort_probs_func_;\n  Function gpu_sample_with_top_p_func_;\n  Function gpu_sampler_take_probs_func_;\n  Function gpu_verify_draft_tokens_func_;\n  Function gpu_renormalize_by_top_p_func_;\n  Optional<Function> flashinfer_multinomial_sample_func_;\n  // Auxiliary Tensors on CPU\n  Tensor uniform_samples_host_;\n  Tensor sample_indices_host_;\n  Tensor top_p_host_;\n  Tensor top_p_init_pivots_host_;\n  Tensor top_prob_offsets_host_;\n  Tensor draft_tokens_host_;\n  Tensor token_tree_first_child_host_;\n  Tensor token_tree_next_sibling_host_;\n  Tensor token_tree_parent_ptr_host_;\n  Tensor sampled_token_ids_host_;\n  Tensor sampled_probs_host_;\n  Tensor top_prob_probs_host_;\n  Tensor 
top_prob_indices_host_;\n  // Auxiliary Tensors on GPU\n  Tensor uniform_samples_device_;\n  Tensor sample_indices_device_;\n  Tensor top_p_device_;\n  Tensor top_p_init_pivots_device_;\n  Tensor top_prob_offsets_device_;\n  Tensor draft_tokens_device_;\n  Tensor token_tree_first_child_device_;\n  Tensor token_tree_next_sibling_device_;\n  Tensor token_tree_parent_ptr_device_;\n  Tensor sampled_token_ids_device_;\n  // The event trace recorder for requests.\n  Optional<EventTraceRecorder> trace_recorder_;\n  // The device stream for the default computation operations.\n  TVMStreamHandle compute_stream_ = nullptr;\n  // The device stream for copying auxiliary data structures to GPU.\n  TVMStreamHandle copy_stream_ = nullptr;\n  const float eps_ = 1e-5;\n  const int num_top_p_cutoff_pivots_ = 3;\n};\n\nSampler Sampler::CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft,\n                                  DLDevice device, Optional<EventTraceRecorder> trace_recorder) {\n  return Sampler(tvm::ffi::make_object<GPUSampler>(max_num_sample, vocab_size, ft, device,\n                                                   std::move(trace_recorder)));\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/sampler/sampler.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/sampler/sampler.h\n * \\brief The header for runtime module of sampler functions.\n */\n\n#ifndef MLC_LLM_SERVE_SAMPLER_SAMPLER_H_\n#define MLC_LLM_SERVE_SAMPLER_SAMPLER_H_\n\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/module.h>\n\n#include \"../../base.h\"\n#include \"../../support/random.h\"\n#include \"../data.h\"\n#include \"../event_trace_recorder.h\"\n#include \"../model.h\"\n#include \"../request_state.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The base class of runtime sampler.\n * Its main function is `BatchSampleTokensWithProbBeforeTopP`, which takes a batch of\n * probability distributions and the corresponding configurations, and samples one token\n * for each instance of the batch.\n */\nclass SamplerObj : public Object {\n public:\n  /*!\n   * \\brief Renormalize the input batch of probability distributions with top p values.\n   * \\param probs_on_device The batch of prob distributions before normalization.\n   * \\param sample_indices Specifying which request we will sample for\n   * in i-th output for the sampling later on.\n   * The output result of the sampling is as follows:\n   *   result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i]);\n   * For renormalization, the sample indices are used to determine the top-p grouping.\n   * \\param request_ids The id of each request.\n   * \\param generation_cfg The generation config of each request in the input batch.\n   * \\return The renormalized probability distributions, residing on device\n   * if the sampler is GPU sampler, or on host if the sampler is CPU sampler.\n   */\n  virtual Tensor BatchRenormalizeProbsByTopP(Tensor probs_on_device,                  //\n                                             const std::vector<int>& sample_indices,  //\n                                             const Array<String>& 
request_ids,        //\n                                             const Array<GenerationConfig>& generation_cfg) = 0;\n\n  /*!\n   * \\brief Sample tokens from the input batch of prob distributions on device.\n   * Top-p has not yet been applied to the input prob distributions.\n   * \\param probs_on_device The prob distributions on GPU to sample tokens from.\n   * \\param sample_indices Specifying which request we should sample for\n   * in i-th output. The output result is sampled as follows:\n   *   result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i]);\n   * \\param request_ids The id of each request.\n   * \\param generation_cfg The generation config of each request\n   * in the input batch.\n   * \\param rngs The random number generator of each sequence.\n   * \\return The batch of sampling results, which contain the sampled token id\n   * and other probability info.\n   */\n  virtual std::vector<SampleResult> BatchSampleTokensWithProbBeforeTopP(\n      Tensor probs_on_device,                         //\n      const std::vector<int>& sample_indices,         //\n      const Array<String>& request_ids,               //\n      const Array<GenerationConfig>& generation_cfg,  //\n      const std::vector<RandomGenerator*>& rngs) = 0;\n\n  /*!\n   * \\brief Sample tokens from the input batch of prob distributions.\n   * Top-p has already been applied to the input prob distributions.\n   * \\param probs The prob distributions.\n   * It resides on GPU if the sampler is GPU sampler, or on host if the sampler is CPU sampler.\n   * \\param sample_indices Specifying which request we should sample for\n   * in i-th output. 
The output result is sampled as follows:\n   *   result[i] = sample_from(prob_on_device[sample_indices[i],:], generation_config[i]);\n   * \\param request_ids The id of each request.\n   * \\param generation_cfg The generation config of each request\n   * in the input batch.\n   * \\param rngs The random number generator of each sequence.\n   * \\return The batch of sampling results, which contain the sampled token id\n   * and other probability info.\n   */\n  virtual std::vector<SampleResult> BatchSampleTokensWithProbAfterTopP(\n      Tensor probs,                                   //\n      const std::vector<int>& sample_indices,         //\n      const Array<String>& request_ids,               //\n      const Array<GenerationConfig>& generation_cfg,  //\n      const std::vector<RandomGenerator*>& rngs) = 0;\n\n  /*!\n   * \\brief Verify draft tokens generated by small models in the large model\n   * in speculative decoding. The input corresponds to a batch of sequences.\n   * Top-p has already been applied to the input prob distributions.\n   * \\param probs The prob distributions to sample tokens from.\n   * It resides on GPU if the sampler is GPU sampler, or on host if the sampler is CPU sampler.\n   * \\param request_ids The id of each request.\n   * \\param cum_verify_lengths The cumulative lengths of the drafts to verify for all sequences.\n   * \\param generation_cfg The generation config of each request\n   * in the input batch.\n   * \\param rngs The random number generator of each sequence.\n   * \\param draft_output_tokens The draft tokens generated by the small model for\n   * each sequence.\n   * \\param token_tree_parent_ptr The parent pointer of the token tree.\n   * \\param draft_probs_on_device The probability distribution computed from the\n   * small model for each sequence. 
Concatenated tensor of shape (total_verify_length, vocab_size).\n   * It includes the slot for the last committed token, which has an undefined probability value.\n   * \\return The list of accepted tokens for each request and the index of the last accepted tree\n   * node for each request.\n   */\n  virtual std::pair<std::vector<std::vector<SampleResult>>, std::vector<int>>\n  BatchVerifyDraftTokensWithProbAfterTopP(\n      Tensor probs, const Array<String>& request_ids, const std::vector<int>& cum_verify_lengths,\n      const Array<GenerationConfig>& generation_cfg, const std::vector<RandomGenerator*>& rngs,\n      const std::vector<std::vector<SampleResult>>& draft_output_tokens,\n      const std::vector<int64_t>& token_tree_parent_ptr, Tensor draft_probs_on_device) = 0;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<SamplerObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.Sampler\", SamplerObj, Object);\n};\n\nclass Sampler : public ObjectRef {\n public:\n  /*! \\brief Create a CPU sampler. */\n  static Sampler CreateCPUSampler(Optional<EventTraceRecorder> trace_recorder);\n  /*!\n   * \\brief Create a GPU sampler.\n   * \\param max_num_sample The max number of samples to sample at a time.\n   * \\param vocab_size The model's vocabulary size.\n   * \\param ft The packed function table.\n   * \\param device The device that the model runs on.\n   * \\param trace_recorder The event trace recorder.\n   */\n  static Sampler CreateGPUSampler(int max_num_sample, int vocab_size, FunctionTable* ft,\n                                  DLDevice device, Optional<EventTraceRecorder> trace_recorder);\n\n  /*! \\brief Check if the given device supports GPU sampling. 
*/\n  static bool SupportGPUSampler(Device device) {\n    return device.device_type == DLDeviceType::kDLCUDA ||\n           device.device_type == DLDeviceType::kDLVulkan ||\n           device.device_type == DLDeviceType::kDLMetal;\n  }\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Sampler, ObjectRef, SamplerObj);\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_SAMPLER_SAMPLER_H_\n"
  },
  {
    "path": "cpp/serve/threaded_engine.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/threaded_engine.cc\n * \\brief The implementation for threaded serving engine in MLC LLM.\n */\n#include \"threaded_engine.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/module.h>\n\n#include <atomic>\n#include <condition_variable>\n#include <mutex>\n#include <optional>\n\n#include \"../support/json_parser.h\"\n#include \"../support/result.h\"\n#include \"engine.h\"\n#include \"request.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing tvm::Device;\nusing namespace tvm::runtime;\n\n/*! \\brief The threaded engine instruction kind. */\nenum class InstructionKind : int {\n  kAddRequest = 0,\n  kAbortRequest = 1,\n  kUnloadEngine = 2,\n  kReloadEngine = 3,\n  kResetEngine = 4,\n  kDebugCallFuncOnAllAllWorker = 5,\n};\n\n/*! \\brief The implementation of ThreadedEngine. */\nclass ThreadedEngineImpl : public ThreadedEngine {\n public:\n  void InitThreadedEngine(Device device, Optional<Function> request_stream_callback,\n                          Optional<EventTraceRecorder> trace_recorder) final {\n    device_ = device;\n    TVM_FFI_ICHECK(request_stream_callback.defined())\n        << \"ThreadedEngine requires request stream callback function, but it is not given.\";\n    request_stream_callback_ = request_stream_callback.value();\n    trace_recorder_ = trace_recorder;\n  }\n\n  void Reload(String engine_config_json_str) final {\n    // NOTE: it is important to reset this flag before sending the\n    // reload instruction to the background thread;\n    // otherwise there can be deadlocks.\n    reload_finished_ = false;\n    bool need_notify = false;\n    {\n      std::lock_guard<std::mutex> lock(background_loop_mutex_);\n      instruction_queue_.emplace_back(InstructionKind::kReloadEngine,\n                                      std::move(engine_config_json_str));\n      ++pending_request_operation_cnt_;\n      need_notify = 
engine_waiting_;\n    }\n    if (need_notify) {\n      background_loop_cv_.notify_one();\n    }\n    {\n      std::unique_lock<std::mutex> lock(reload_unload_mutex_);\n      reload_unload_cv_.wait(lock, [this] { return reload_finished_; });\n    }\n  }\n\n  void Unload() final {\n    // NOTE: it is important to reset this flag before sending the\n    // unload instruction to the background thread; otherwise there can be\n    // deadlocks, e.g. the background thread finishes the unload job and sets\n    // the flag to true, and then we set it back to false.\n    unload_finished_ = false;\n    bool need_notify = false;\n    {\n      std::lock_guard<std::mutex> lock(background_loop_mutex_);\n      instruction_queue_.emplace_back(InstructionKind::kUnloadEngine, ObjectRef(nullptr));\n      ++pending_request_operation_cnt_;\n      need_notify = engine_waiting_;\n    }\n    if (need_notify) {\n      background_loop_cv_.notify_one();\n    }\n    {\n      std::unique_lock<std::mutex> lock(reload_unload_mutex_);\n      reload_unload_cv_.wait(lock, [this] { return unload_finished_; });\n    }\n  }\n\n  void Reset() final {\n    bool need_notify = false;\n    {\n      std::lock_guard<std::mutex> lock(background_loop_mutex_);\n      instruction_queue_.emplace_back(InstructionKind::kResetEngine, ObjectRef(nullptr));\n      ++pending_request_operation_cnt_;\n      need_notify = engine_waiting_;\n    }\n    if (need_notify) {\n      background_loop_cv_.notify_one();\n    }\n  }\n\n  void AddRequest(Request request) final {\n    bool need_notify = false;\n    {\n      std::lock_guard<std::mutex> lock(background_loop_mutex_);\n      instruction_queue_.emplace_back(InstructionKind::kAddRequest, request);\n      ++pending_request_operation_cnt_;\n      need_notify = engine_waiting_;\n    }\n    if (need_notify) {\n      background_loop_cv_.notify_one();\n    }\n  }\n\n  void AbortRequest(const String& request_id) final {\n    bool need_notify = false;\n    {\n      std::lock_guard<std::mutex> 
lock(background_loop_mutex_);\n      instruction_queue_.emplace_back(InstructionKind::kAbortRequest, request_id);\n      ++pending_request_operation_cnt_;\n      need_notify = engine_waiting_;\n    }\n    if (need_notify) {\n      background_loop_cv_.notify_one();\n    }\n  }\n\n  void RunBackgroundLoop() final {\n    // The local vectors that load the requests from critical regions.\n    std::vector<std::pair<InstructionKind, Any>> local_instruction_queue;\n\n    while (!exit_now_.load(std::memory_order_relaxed)) {\n      {\n        std::unique_lock<std::mutex> lock(background_loop_mutex_);\n        engine_waiting_ = true;\n        background_loop_cv_.wait(lock, [this] {\n          return (background_engine_ != nullptr && !background_engine_->Empty()) ||\n                 pending_request_operation_cnt_.load() > 0 ||\n                 exit_now_.load(std::memory_order_relaxed);\n        });\n        engine_waiting_ = false;\n        local_instruction_queue = instruction_queue_;\n        instruction_queue_.clear();\n        pending_request_operation_cnt_ = 0;\n      }\n      for (const auto& [kind, arg] : local_instruction_queue) {\n        if (kind == InstructionKind::kAddRequest) {\n          TVM_FFI_ICHECK(background_engine_ != nullptr) << \"Background engine is not loaded.\";\n          background_engine_->AddRequest(Downcast<Request>(arg));\n        } else if (kind == InstructionKind::kAbortRequest) {\n          // In rare cases, an abort request can arrive after unloading,\n          // i.e. when the background engine is nullptr. This happens when an\n          // ongoing generation is interrupted, the engine gets unloaded, and\n          // then abort is called. It is safe to ignore such aborts.\n          if (background_engine_ != nullptr) {\n            background_engine_->AbortRequest(Downcast<String>(arg));\n          }\n        } else if (kind == InstructionKind::kUnloadEngine) {\n          EngineUnloadImpl();\n        } else if (kind == 
InstructionKind::kReloadEngine) {\n          EngineUnloadImpl();\n          EngineReloadImpl(Downcast<String>(arg));\n        } else if (kind == InstructionKind::kResetEngine) {\n          if (background_engine_ != nullptr) {\n            background_engine_->Reset();\n          }\n        } else if (kind == InstructionKind::kDebugCallFuncOnAllAllWorker) {\n          TVM_FFI_ICHECK(background_engine_ != nullptr) << \"Background engine is not loaded.\";\n          Array<Any> packed_args = Downcast<Array<Any>>(arg);\n          background_engine_->DebugCallFuncOnAllAllWorker(\n              Downcast<String>(packed_args[0]), Downcast<Optional<String>>(packed_args[1]));\n        } else {\n          LOG(FATAL) << \"Cannot reach here\";\n        }\n      }\n      if (background_engine_ != nullptr) {\n        background_engine_->Step();\n      }\n    }\n  }\n\n  void RunBackgroundStreamBackLoop() final {\n    // The local vectors that load the request stream callback inputs from critical regions.\n    std::vector<Array<RequestStreamOutput>> local_request_stream_callback_inputs;\n    std::vector<RequestStreamOutput> flattened_callback_inputs;\n\n    while (!exit_now_.load(std::memory_order_relaxed)) {\n      {\n        std::unique_lock<std::mutex> lock(request_stream_callback_mutex_);\n        stream_callback_waiting_ = true;\n        request_stream_callback_cv_.wait(lock, [this] {\n          return pending_request_stream_callback_cnt_.load() > 0 ||\n                 exit_now_.load(std::memory_order_relaxed);\n        });\n        stream_callback_waiting_ = false;\n\n        local_request_stream_callback_inputs = request_stream_callback_inputs_;\n        request_stream_callback_inputs_.clear();\n        pending_request_stream_callback_cnt_ = 0;\n      }\n      for (const Array<RequestStreamOutput>& callback_inputs :\n           local_request_stream_callback_inputs) {\n        for (const RequestStreamOutput& callback_input : callback_inputs) {\n          
flattened_callback_inputs.push_back(callback_input);\n        }\n      }\n      if (!flattened_callback_inputs.empty()) {\n        request_stream_callback_(Array<RequestStreamOutput>(flattened_callback_inputs));\n      }\n      flattened_callback_inputs.clear();\n    }\n  }\n\n  void ExitBackgroundLoop() final {\n    {\n      std::lock_guard<std::mutex> lock(background_loop_mutex_);\n      exit_now_.store(true);\n    }\n    background_loop_cv_.notify_one();\n    {\n      // Briefly hold the stream callback mutex so that the stream-back loop cannot miss the\n      // exit flag between checking its wait predicate and blocking on the condition variable.\n      std::lock_guard<std::mutex> lock(request_stream_callback_mutex_);\n    }\n    request_stream_callback_cv_.notify_one();\n  }\n\n  /************** Query/Profile/Debug **************/\n\n  GenerationConfig GetDefaultGenerationConfig() const final {\n    TVM_FFI_ICHECK(default_generation_config_.defined())\n        << \"The default generation config has not been set.\";\n    return default_generation_config_.value();\n  }\n\n  Request CreateRequest(String id, Array<Data> inputs, String generation_cfg_json_str) const {\n    json::Object config = json::ParseToJSONObject(generation_cfg_json_str);\n    auto gen_config = GenerationConfig::FromJSON(config, GetDefaultGenerationConfig());\n    TVM_FFI_ICHECK(gen_config.IsOk()) << gen_config.UnwrapErr();\n    return Request(std::move(id), std::move(inputs), gen_config.Unwrap());\n  }\n\n  EngineConfig GetCompleteEngineConfig() const final {\n    TVM_FFI_ICHECK(complete_engine_config_.defined()) << \"The engine config has not been set.\";\n    return complete_engine_config_.value();\n  }\n\n  String GetCompleteEngineConfigJSONString() const {\n    return GetCompleteEngineConfig()->AsJSONString();\n  }\n\n  void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) final {\n    bool need_notify = false;\n    {\n      std::lock_guard<std::mutex> lock(background_loop_mutex_);\n      instruction_queue_.emplace_back(InstructionKind::kDebugCallFuncOnAllAllWorker,\n                                      Array<Any>{func_name, func_args});\n      ++pending_request_operation_cnt_;\n      need_notify = engine_waiting_;\n    }\n    
if (need_notify) {\n      background_loop_cv_.notify_one();\n    }\n  }\n\n private:\n  void EngineReloadImpl(const std::string& engine_config_json_str) {\n    auto frequest_stream_callback_wrapper = [this](Array<RequestStreamOutput> delta_outputs) {\n      bool need_notify = false;\n      {\n        std::lock_guard<std::mutex> lock(request_stream_callback_mutex_);\n        request_stream_callback_inputs_.push_back(std::move(delta_outputs));\n        ++pending_request_stream_callback_cnt_;\n        need_notify = stream_callback_waiting_;\n      }\n      if (need_notify) {\n        request_stream_callback_cv_.notify_one();\n      }\n    };\n\n    FRequestStreamCallback request_stream_callback(frequest_stream_callback_wrapper);\n    Result<EngineCreationOutput> output_res =\n        Engine::Create(engine_config_json_str, device_, request_stream_callback, trace_recorder_);\n    TVM_FFI_ICHECK(output_res.IsOk()) << output_res.UnwrapErr();\n    EngineCreationOutput output = output_res.Unwrap();\n    background_engine_ = std::move(output.reloaded_engine);\n    default_generation_config_ = output.default_generation_cfg;\n    complete_engine_config_ = output.completed_engine_config;\n    {\n      // Wake up the thread waiting for reload finish.\n      std::lock_guard<std::mutex> lock(reload_unload_mutex_);\n      reload_finished_ = true;\n    }\n    reload_unload_cv_.notify_one();\n  }\n\n  void EngineUnloadImpl() {\n    if (background_engine_ != nullptr) {\n      background_engine_->AbortAllRequests();\n      background_engine_ = nullptr;\n      // Clear the allocated memory in cached memory pool.\n      static Function fclear_memory_manager =\n          Function::GetGlobalRequired(\"vm.builtin.memory_manager.clear\");\n      fclear_memory_manager();\n      default_generation_config_ = std::nullopt;\n      complete_engine_config_ = std::nullopt;\n    }\n    {\n      // Wake up the thread waiting for unload finish.\n      std::lock_guard<std::mutex> 
lock(reload_unload_mutex_);\n      unload_finished_ = true;\n    }\n    reload_unload_cv_.notify_one();\n  }\n\n  /*! \\brief The device to run models on. */\n  Device device_;\n  /*! \\brief The background normal engine for request processing. */\n  std::unique_ptr<Engine> background_engine_;\n  /*! \\brief The request stream callback. */\n  Function request_stream_callback_;\n  /*! \\brief Event trace recorder. */\n  Optional<EventTraceRecorder> trace_recorder_;\n\n  /*! \\brief The complete engine config. */\n  Optional<EngineConfig> complete_engine_config_;\n  /*! \\brief The default generation config. */\n  Optional<GenerationConfig> default_generation_config_;\n\n  /*! \\brief Mutexes guarding the background loop, stream callback and reload/unload state. */\n  std::mutex background_loop_mutex_;\n  std::mutex request_stream_callback_mutex_;\n  std::mutex reload_unload_mutex_;\n  /*! \\brief Condition variables that let the corresponding threads wait without spinning. */\n  std::condition_variable background_loop_cv_;\n  std::condition_variable request_stream_callback_cv_;\n  std::condition_variable reload_unload_cv_;\n  /*! \\brief A flag indicating whether the background loops need to exit. 
*/\n  std::atomic<bool> exit_now_ = false;\n\n  /************** Critical Regions **************/\n  /*!\n   * \\brief The instruction queue for the threaded engine.\n   * The instructions include:\n   *  - requests to add into the background engine,\n   *  - requests to abort from the background engine,\n   *  - engine unload/reload,\n   *  - and other debugging instructions.\n   * Elements are sent from other threads and consumed by\n   * the threaded engine in the background loop.\n   */\n  std::vector<std::pair<InstructionKind, Any>> instruction_queue_;\n  /*!\n   * \\brief The delta outputs to pass to the request stream callback.\n   * Elements are sent from the background loop thread and\n   * consumed by the stream-back loop thread.\n   */\n  std::vector<Array<RequestStreamOutput>> request_stream_callback_inputs_;\n  /*!\n   * \\brief Number of pending request operations. It should equal the size of\n   * `instruction_queue_`.\n   */\n  std::atomic<int> pending_request_operation_cnt_ = 0;\n  /*!\n   * \\brief Number of pending request stream callback invocations.\n   * It should be the size of `request_stream_callback_inputs_`.\n   */\n  std::atomic<int> pending_request_stream_callback_cnt_ = 0;\n  /*! \\brief A boolean flag indicating if the engine is waiting for new instructions. */\n  bool engine_waiting_ = false;\n  /*! \\brief A boolean flag indicating if the stream callback loop is waiting. */\n  bool stream_callback_waiting_ = false;\n  /*! \\brief A boolean indicating if the engine reload has finished. */\n  bool reload_finished_ = false;\n  /*! \\brief A boolean indicating if the engine unload has finished. */\n  bool unload_finished_ = false;\n};\n\n/*! \\brief The module wrapper that exposes ThreadedEngineImpl through the TVM FFI. 
*/\nclass ThreadedEngineModule : public ThreadedEngineImpl, public ffi::ModuleObj {\n public:\n  TVM_MODULE_VTABLE_BEGIN(\"mlc.serve.async_threaded_engine\");\n  TVM_MODULE_VTABLE_ENTRY(\"init_threaded_engine\", &ThreadedEngineImpl::InitThreadedEngine);\n  TVM_MODULE_VTABLE_ENTRY(\"reload\", &ThreadedEngineImpl::Reload);\n  TVM_MODULE_VTABLE_ENTRY(\"add_request\", &ThreadedEngineImpl::AddRequest);\n  TVM_MODULE_VTABLE_ENTRY(\"create_request\", &ThreadedEngineImpl::CreateRequest);\n  TVM_MODULE_VTABLE_ENTRY(\"abort_request\", &ThreadedEngineImpl::AbortRequest);\n  TVM_MODULE_VTABLE_ENTRY(\"run_background_loop\", &ThreadedEngineImpl::RunBackgroundLoop);\n  TVM_MODULE_VTABLE_ENTRY(\"run_background_stream_back_loop\",\n                          &ThreadedEngineImpl::RunBackgroundStreamBackLoop);\n  TVM_MODULE_VTABLE_ENTRY(\"exit_background_loop\", &ThreadedEngineImpl::ExitBackgroundLoop);\n  TVM_MODULE_VTABLE_ENTRY(\"get_complete_engine_config\",\n                          &ThreadedEngineImpl::GetCompleteEngineConfigJSONString);\n  TVM_MODULE_VTABLE_ENTRY(\"reset\", &ThreadedEngineImpl::Reset);\n  TVM_MODULE_VTABLE_ENTRY(\"debug_call_func_on_all_worker\",\n                          &ThreadedEngineImpl::DebugCallFuncOnAllAllWorker);\n  TVM_MODULE_VTABLE_END();\n};\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef().def(\"mlc.serve.create_threaded_engine\",\n                        []() { return Module(tvm::ffi::make_object<ThreadedEngineModule>()); });\n}\n\nstd::unique_ptr<ThreadedEngine> ThreadedEngine::Create() {\n  std::unique_ptr<ThreadedEngineImpl> threaded_engine = std::make_unique<ThreadedEngineImpl>();\n  return std::move(threaded_engine);\n}\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/serve/threaded_engine.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file serve/threaded_engine.h\n * \\brief The header of the threaded serving engine in MLC LLM.\n */\n#ifndef MLC_LLM_SERVE_THREADED_ENGINE_H_\n#define MLC_LLM_SERVE_THREADED_ENGINE_H_\n\n#include \"data.h\"\n#include \"engine.h\"\n#include \"request.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace serve {\n\nusing namespace tvm::runtime;\n\n/*!\n * \\brief The interface of the threaded engine in MLC LLM.\n * The threaded engine keeps running a background request processing\n * loop on a standalone thread. It exposes thread-safe `AddRequest` and\n * `AbortRequest` methods to receive new requests or abortions from other\n * threads, while the internal request processing is delegated to a\n * normal engine wrapped inside.\n */\nclass ThreadedEngine {\n public:\n  /*! \\brief Create a ThreadedEngine. */\n  static std::unique_ptr<ThreadedEngine> Create();\n\n  virtual ~ThreadedEngine() = default;\n\n  /*!\n   * \\brief Initialize the threaded engine.\n   * \\param device The device on which to run models.\n   * \\param request_stream_callback The callback invoked with streamed request outputs.\n   * \\param trace_recorder Event trace recorder for requests.\n   */\n  virtual void InitThreadedEngine(Device device, Optional<Function> request_stream_callback,\n                                  Optional<EventTraceRecorder> trace_recorder) = 0;\n\n  /*!\n   * \\brief Reload the engine with the new engine config.\n   * \\param engine_config_json_str The engine config JSON string.\n   */\n  virtual void Reload(String engine_config_json_str) = 0;\n\n  /*! \\brief Unload the background engine. */\n  virtual void Unload() = 0;\n\n  /*! \\brief Reset the engine to the initial state. */\n  virtual void Reset() = 0;\n\n  /*! \\brief Starts the background request processing loop. */\n  virtual void RunBackgroundLoop() = 0;\n\n  /*! \\brief Starts the request stream callback loop. 
*/\n  virtual void RunBackgroundStreamBackLoop() = 0;\n\n  /*!\n   * \\brief Notify the ThreadedEngine to exit the background request\n   * processing loop and the request stream callback loop. This method is\n   * invoked by threads other than the engine-driving thread.\n   */\n  virtual void ExitBackgroundLoop() = 0;\n\n  /*! \\brief Add a new request to the engine. */\n  virtual void AddRequest(Request request) = 0;\n\n  /*! \\brief Abort the request with the given id string from the engine. */\n  virtual void AbortRequest(const String& request_id) = 0;\n\n  /************** Query/Profile/Debug **************/\n\n  /*! \\brief Return the default generation config. */\n  virtual GenerationConfig GetDefaultGenerationConfig() const = 0;\n\n  /*! \\brief Return the complete engine config. */\n  virtual EngineConfig GetCompleteEngineConfig() const = 0;\n\n  /*! \\brief Call the given global function on all workers. For debugging purposes only. */\n  virtual void DebugCallFuncOnAllAllWorker(const String& func_name, Optional<String> func_args) = 0;\n};\n\n}  // namespace serve\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SERVE_THREADED_ENGINE_H_\n"
  },
  {
    "path": "cpp/support/debug_utils.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file support/debug_utils.h\n * \\brief Tools for debug purposes.\n */\n#ifndef MLC_LLM_SUPPORT_DEBUG_UTILS_H_\n#define MLC_LLM_SUPPORT_DEBUG_UTILS_H_\n\n#include \"../tokenizers/tokenizers.h\"\n\nnamespace mlc {\nnamespace llm {\n\n/*! \\brief A registry for debug information. */\nclass DebugRegistry {\n public:\n  static DebugRegistry* Global() {\n    static DebugRegistry reg;\n    return &reg;\n  }\n\n  // Tokenizer information, helpful for converting token id to token string in debugging\n  Tokenizer tokenizer;\n};\n\n/*! \\brief Register the tokenizer to the global tokenizer registry. */\ninline void DebugRegisterTokenizer(const Tokenizer& tokenizer) {\n  DebugRegistry::Global()->tokenizer = tokenizer;\n}\n\n/*! \\brief Get the registered tokenizer from the global tokenizer registry. */\ninline Tokenizer DebugGetTokenizer() { return DebugRegistry::Global()->tokenizer; }\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_DEBUG_UTILS_H_\n"
  },
  {
    "path": "cpp/support/dynamic_bitset.h",
    "content": "/*!\n * Copyright (c) 2023-2025 by Contributors\n * \\file support/dynamic_bitset.h\n * \\brief The header for utilities used in grammar-guided generation.\n */\n#ifndef MLC_LLM_SUPPORT_DYNAMIC_BITSET_H_\n#define MLC_LLM_SUPPORT_DYNAMIC_BITSET_H_\n\n#include <tvm/runtime/logging.h>\n\n#include <cstdint>\n#include <cstring>\n#include <vector>\n\nnamespace mlc {\nnamespace llm {\n\n/*!\n * \\brief A bitset whose length is specified at runtime. Note the size cannot be changed after\n * construction.\n * \\details The buffer of the bitset is a uint32_t array. There are two uses for this class:\n * - When passing nullptr to data, it maintains an internal buffer for the bitset.\n * - When passing a pointer to a buffer with enough size, it uses the external buffer for the\n *   bitset.\n */\nclass DynamicBitset {\n public:\n  /*!\n   * \\brief Calculate the minimal size of the uint32_t buffer for the bitset with the given size.\n   * \\param element_size The size of the bitset.\n   * \\return The minimal buffer size.\n   */\n  static int CalculateBufferSize(int element_size) { return (element_size + 31) / 32; }\n\n  /*!\n   * \\brief Construct an empty bitset. This object should be assigned a valid bitset before use.\n   */\n  DynamicBitset() : size_(0), buffer_size_(0), data_(nullptr), is_internal_(true) {}\n\n  /*!\n   * \\brief Construct a bitset with the given size.\n   * \\param size The size of the bitset.\n   * \\param data The buffer for the bitset. If nullptr, the bitset will maintain an internal buffer.\n   */\n  DynamicBitset(int size, uint32_t* data = nullptr)\n      : size_(size), buffer_size_(CalculateBufferSize(size)) {\n    if (data == nullptr) {\n      internal_buffer_.resize(buffer_size_, 0);\n      data_ = internal_buffer_.data();\n      is_internal_ = true;\n    } else {\n      data_ = data;\n      is_internal_ = false;\n    }\n  }\n\n  /*! \\brief Copy assignment. 
*/\n  DynamicBitset& operator=(const DynamicBitset& other) {\n    TVM_FFI_DCHECK(is_internal_ || size_ >= other.size_)\n        << \"Expanding bitset size is not allowed when the \"\n           \"memory of the bitset is externally managed\";\n    size_ = other.size_;\n    buffer_size_ = other.buffer_size_;\n    if (is_internal_) {\n      // Use resize (not reserve) so that the memcpy below writes within the vector's size.\n      internal_buffer_.resize(buffer_size_);\n      data_ = internal_buffer_.data();\n    }\n    if (data_ != other.data_) {\n      std::memcpy(data_, other.data_, buffer_size_ * sizeof(uint32_t));\n    }\n    return *this;\n  }\n\n  /*! \\brief Move assignment. */\n  DynamicBitset& operator=(DynamicBitset&& other) {\n    size_ = other.size_;\n    buffer_size_ = other.buffer_size_;\n    is_internal_ = other.is_internal_;\n    if (is_internal_) {\n      internal_buffer_ = std::move(other.internal_buffer_);\n      data_ = internal_buffer_.data();\n    } else {\n      data_ = other.data_;\n    }\n    return *this;\n  }\n\n  /*! \\brief Get the value of the bit at the given index. */\n  bool operator[](int index) const {\n    TVM_FFI_DCHECK(data_ && index >= 0 && index < size_);\n    return (data_[index / 32] >> (index % 32)) & 1;\n  }\n\n  /*! \\brief Get the size of the bitset. */\n  int Size() const { return size_; }\n\n  /*! \\brief Set the whole bitset to true. */\n  void Set() {\n    TVM_FFI_DCHECK(data_);\n    std::memset(data_, 0xFF, buffer_size_ * sizeof(uint32_t));\n  }\n\n  /*! \\brief Set the bit at the given index to the given value. */\n  void Set(int index, bool value = true) {\n    TVM_FFI_DCHECK(data_ && index >= 0 && index < size_);\n    if (value) {\n      data_[index / 32] |= 1 << (index % 32);\n    } else {\n      data_[index / 32] &= ~(1 << (index % 32));\n    }\n  }\n\n  /*! \\brief Set the whole bitset to false. */\n  void Reset() {\n    TVM_FFI_DCHECK(data_);\n    std::memset(data_, 0, buffer_size_ * sizeof(uint32_t));\n  }\n\n  /*! \\brief Set the bit at the given index to false. 
*/\n  void Reset(int index) { Set(index, false); }\n\n  /*! \\brief Perform a bitwise OR operation between the current bitset and another bitset. */\n  DynamicBitset& operator|=(const DynamicBitset& other) {\n    TVM_FFI_DCHECK(buffer_size_ <= other.buffer_size_);\n    for (int i = 0; i < buffer_size_; ++i) {\n      data_[i] |= other.data_[i];\n    }\n    return *this;\n  }\n\n private:\n  // The size of the bitset.\n  int size_;\n  // The size of the buffer.\n  int buffer_size_;\n  // The buffer for the bitset.\n  uint32_t* data_;\n  // The internal buffer. It is empty if not needed.\n  std::vector<uint32_t> internal_buffer_;\n  // Whether the buffer is internally managed.\n  bool is_internal_;\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_DYNAMIC_BITSET_H_\n"
  },
  {
    "path": "cpp/support/encoding.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file support/encoding.cc\n */\n#include \"encoding.h\"\n\n#include <tvm/runtime/logging.h>\n\n#include <array>\n#include <iomanip>\n#include <sstream>\n#include <tuple>\n\nnamespace mlc {\nnamespace llm {\n\nstd::string PrintAsUTF8(TCodepoint codepoint) {\n  TVM_FFI_ICHECK(codepoint <= 0x10FFFF) << \"Invalid codepoint: \" << codepoint;\n  std::string utf8;\n  if (codepoint <= 0x7F) {\n    // 1-byte sequence\n    utf8 += static_cast<char>(codepoint);\n  } else if (codepoint <= 0x7FF) {\n    // 2-byte sequence\n    utf8 += static_cast<char>(0xC0 | ((codepoint >> 6) & 0x1F));\n    utf8 += static_cast<char>(0x80 | (codepoint & 0x3F));\n  } else if (codepoint <= 0xFFFF) {\n    // 3-byte sequence\n    utf8 += static_cast<char>(0xE0 | ((codepoint >> 12) & 0x0F));\n    utf8 += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F));\n    utf8 += static_cast<char>(0x80 | (codepoint & 0x3F));\n  } else {\n    // 4-byte sequence\n    utf8 += static_cast<char>(0xF0 | ((codepoint >> 18) & 0x07));\n    utf8 += static_cast<char>(0x80 | ((codepoint >> 12) & 0x3F));\n    utf8 += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F));\n    utf8 += static_cast<char>(0x80 | (codepoint & 0x3F));\n  }\n  return utf8;\n}\n\nstd::string PrintAsEscaped(\n    TCodepoint codepoint,\n    const std::unordered_map<TCodepoint, std::string>& additional_escape_map) {\n  static const std::unordered_map<TCodepoint, std::string> kCodepointToEscape = {\n      {'\\'', \"\\\\\\'\"}, {'\\\"', \"\\\\\\\"\"}, {'\\?', \"\\\\\\?\"}, {'\\\\', \"\\\\\\\\\"}, {'\\a', \"\\\\a\"},\n      {'\\b', \"\\\\b\"},  {'\\f', \"\\\\f\"},  {'\\n', \"\\\\n\"},  {'\\r', \"\\\\r\"},  {'\\t', \"\\\\t\"},\n      {'\\v', \"\\\\v\"},  {'\\0', \"\\\\0\"},  {'\\x1B', \"\\\\e\"}};\n\n  if (auto it = additional_escape_map.find(codepoint); it != additional_escape_map.end()) {\n    return it->second;\n  }\n\n  if (auto it = kCodepointToEscape.find(codepoint); it != kCodepointToEscape.end()) {\n    return it->second;\n  }\n\n  if 
(codepoint >= 0x20 && codepoint <= 0x7E) {\n    return std::string({static_cast<char>(codepoint)});\n  }\n\n  // convert codepoint to hex\n  char prefix = codepoint <= 0xFF ? 'x' : codepoint <= 0xFFFF ? 'u' : 'U';\n  int width = codepoint <= 0xFF ? 2 : codepoint <= 0xFFFF ? 4 : 8;\n  std::stringstream ss;\n  ss << std::setfill('0') << std::setw(width) << std::hex << codepoint;\n  auto hex = ss.str();\n  return std::string(\"\\\\\") + prefix + hex;\n}\n\nstd::string PrintAsEscaped(uint8_t raw_char) {\n  // Forward to the TCodepoint overload. Calling PrintAsEscaped(raw_char) here would select this\n  // same uint8_t overload again and recurse infinitely.\n  return PrintAsEscaped(static_cast<TCodepoint>(raw_char));\n}\n\nstd::string PrintAsEscaped(std::string raw_str) {\n  std::string res;\n  auto codepoints = ParseUTF8(raw_str.c_str(), UTF8ErrorPolicy::kReturnByte);\n  for (auto c : codepoints) {\n    res += PrintAsEscaped(c);\n  }\n  return res;\n}\n\nstd::tuple<bool, int, TCodepoint> HandleUTF8FirstByte(uint8_t byte) {\n  static const std::array<int8_t, 5> kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07};\n  // clang-format off\n  static const std::array<int, 256> kUtf8Bytes = {\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n     1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n     2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,\n     2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,\n     3,  3,  3,  3,  
3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,\n     4,  4,  4,  4,  4,  4,  4,  4, -1, -1, -1, -1, -1, -1, -1, -1,\n  };\n  // clang-format on\n  auto num_bytes = kUtf8Bytes[static_cast<uint8_t>(byte)];\n  if (num_bytes == -1) {\n    return {false, 0, 0};\n  }\n  return {true, num_bytes, byte & kFirstByteMask[num_bytes]};\n}\n\nstd::pair<TCodepoint, const char*> ParseNextUTF8(const char* utf8, UTF8ErrorPolicy error_policy) {\n  auto [accepted, num_bytes, res] = HandleUTF8FirstByte(utf8[0]);\n  if (accepted) {\n    for (int i = 1; i < num_bytes; ++i) {\n      if (utf8[i] == 0 || (static_cast<uint8_t>(utf8[i]) & 0xC0) != 0x80) {\n        // invalid utf8\n        accepted = false;\n        break;\n      }\n      res = (res << 6) | (static_cast<uint8_t>(utf8[i]) & 0x3F);\n    }\n  }\n\n  if (!accepted) {\n    // invalid utf8\n    if (error_policy == UTF8ErrorPolicy::kReturnInvalid) {\n      return {CharHandlingError::kInvalidUTF8, utf8};\n    } else {\n      return {static_cast<unsigned char>(utf8[0]), utf8 + 1};\n    }\n  }\n\n  return {res, utf8 + num_bytes};\n}\n\nstd::vector<TCodepoint> ParseUTF8(const char* utf8, UTF8ErrorPolicy error_policy) {\n  std::vector<TCodepoint> codepoints;\n  while (*utf8 != 0) {\n    TCodepoint codepoint;\n    std::tie(codepoint, utf8) = ParseNextUTF8(utf8, error_policy);\n    if (codepoint == CharHandlingError::kInvalidUTF8) {\n      return {codepoint};\n    }\n    codepoints.push_back(codepoint);\n  }\n  return codepoints;\n}\n\ninline int HexCharToInt(char c) {\n  if (c >= '0' && c <= '9') {\n    return c - '0';\n  } else if (c >= 'a' && c <= 'f') {\n    return c - 'a' + 10;\n  } else if (c >= 'A' && c <= 'F') {\n    return c - 'A' + 10;\n  } else {\n    return -1;\n  }\n}\n\nstd::pair<TCodepoint, const char*> ParseNextUTF8OrEscaped(\n    const char* utf8, const std::unordered_map<std::string, TCodepoint>& additional_escape_map) {\n  static const std::unordered_map<std::string, TCodepoint> kEscapeToCodepoint = {\n      {\"\\\\\\'\", 
'\\''}, {\"\\\\\\\"\", '\\\"'}, {\"\\\\\\?\", '\\?'}, {\"\\\\\\\\\", '\\\\'}, {\"\\\\a\", '\\a'},\n      {\"\\\\b\", '\\b'},  {\"\\\\f\", '\\f'},  {\"\\\\n\", '\\n'},  {\"\\\\r\", '\\r'},  {\"\\\\t\", '\\t'},\n      {\"\\\\v\", '\\v'},  {\"\\\\0\", '\\0'},  {\"\\\\e\", '\\x1B'}};\n  if (utf8[0] != '\\\\') {\n    return ParseNextUTF8(utf8, UTF8ErrorPolicy::kReturnInvalid);\n  }\n\n  auto escape_sequence = std::string(utf8, 2);\n  if (auto it = additional_escape_map.find(escape_sequence); it != additional_escape_map.end()) {\n    return {it->second, utf8 + 2};\n  }\n  if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) {\n    return {it->second, utf8 + 2};\n  }\n\n  if (utf8[1] == 'x') {\n    // arbitrary length hex\n    int len = 0;\n    int32_t codepoint = 0;\n    while (true) {\n      auto digit = HexCharToInt(utf8[2 + len]);\n      if (digit == -1) {\n        break;\n      }\n      codepoint = codepoint * 16 + digit;\n      ++len;\n    }\n    if (len == 0) {\n      return {CharHandlingError::kInvalidEscape, utf8};\n    }\n    return {codepoint, utf8 + len + 2};\n  } else if (utf8[1] == 'u' || utf8[1] == 'U') {\n    // 4- or 8-digit hex\n    int len = utf8[1] == 'u' ? 4 : 8;\n    int32_t codepoint = 0;\n\n    for (int i = 0; i < len; ++i) {\n      auto digit = HexCharToInt(utf8[i + 2]);\n      if (digit == -1) {\n        return {CharHandlingError::kInvalidEscape, utf8};\n      }\n      codepoint = codepoint * 16 + digit;\n    }\n    return {codepoint, utf8 + len + 2};\n  } else {\n    return {CharHandlingError::kInvalidEscape, utf8};\n  }\n}\n\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/support/encoding.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file support/encoding.h\n * \\brief Encoding and decoding from/to UTF-8 and escape sequence to/from codepoints.\n */\n#ifndef MLC_LLM_SUPPORT_ENCODING_H_\n#define MLC_LLM_SUPPORT_ENCODING_H_\n\n#include <cstdint>\n#include <string>\n#include <tuple>\n#include <unordered_map>\n#include <utility>\n#include <vector>\n\nnamespace mlc {\nnamespace llm {\n\n/*! \\brief Represents a unicode codepoint. */\nusing TCodepoint = int32_t;\n\n/*!\n * \\brief Handle the first byte of a UTF-8 sequence.\n * \\returns (is_valid, total_number_of_bytes, initial_codepoint).\n */\nstd::tuple<bool, int, TCodepoint> HandleUTF8FirstByte(uint8_t byte);\n\n/*!\n * \\brief Print a codepoint to a UTF-8 string.\n * \\param codepoint The codepoint.\n * \\return The UTF-8 string.\n */\nstd::string PrintAsUTF8(TCodepoint codepoint);\n\n/*!\n * \\brief Print a codepoint as an escaped string. If the codepoint is not printable, it is\n * escaped. By default the function supports C escape sequences (\"\\n\", \"\\t\", \"\\u0123\"). Users\n * can specify more escape sequences using additional_escape_map.\n * \\param codepoint The codepoint.\n * \\param additional_escape_map A map from codepoint to escape sequence. If the codepoint is in the\n * map, it is escaped using the corresponding escape sequence, e.g. {{'-', \"\\\\-\"}}.\n * \\return The printable string.\n */\nstd::string PrintAsEscaped(\n    TCodepoint codepoint,\n    const std::unordered_map<TCodepoint, std::string>& additional_escape_map = {});\n\n/*!\n * \\brief Print the given char as a printable escaped string.\n * \\return The escaped string.\n */\nstd::string PrintAsEscaped(uint8_t raw_char);\n\n/*!\n * \\brief Print the given string as a printable escaped string.\n * \\return The escaped string.\n */\nstd::string PrintAsEscaped(std::string raw_str);\n\n/*!\n * \\brief Represents an error when handling characters. 
Will be returned as a special TCodepoint\n * value.\n */\nenum CharHandlingError : TCodepoint {\n  /*! \\brief The UTF-8 string is invalid. */\n  kInvalidUTF8 = -10,\n  /*! \\brief The escape sequence is invalid. */\n  kInvalidEscape = -11,\n};\n\n/*!\n * \\brief The policy for handling invalid UTF-8 sequences.\n */\nenum class UTF8ErrorPolicy {\n  /*! \\brief Return an error codepoint when an error is encountered. */\n  kReturnInvalid,\n  /*! \\brief Return the first byte of the invalid sequence as a codepoint and continue parsing. */\n  kReturnByte,\n};\n\n/*!\n * \\brief Parse the first codepoint in a UTF-8 string.\n * \\param utf8 The UTF-8 string.\n * \\param error_policy The policy for handling invalid UTF-8 sequences.\n * \\return The codepoint and new pointer. If the UTF-8 string is invalid, and the error policy is\n * kReturnInvalid, the function returns (CharHandlingError::kInvalidUTF8, input char pointer).\n */\nstd::pair<TCodepoint, const char*> ParseNextUTF8(\n    const char* utf8, UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid);\n\n/*!\n * \\brief Parse all codepoints in a UTF-8 string.\n * \\param utf8 The UTF-8 string.\n * \\param error_policy The policy for handling invalid UTF-8 sequences.\n * \\return All codepoints. If the UTF-8 string is invalid, and the error policy is\n * kReturnInvalid, the function returns {CharHandlingError::kInvalidUTF8}.\n */\nstd::vector<TCodepoint> ParseUTF8(const char* utf8,\n                                  UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid);\n\n/*!\n * \\brief Parse the first codepoint from a UTF-8 string. Also checks escape sequences and converts\n * the escaped char to its original value.\n * \\param utf8 The UTF-8 string or the escape sequence.\n * \\param additional_escape_map A map from escape sequence to codepoint. If the escape sequence is\n * in the map, it is converted to the corresponding codepoint, e.g. {{\"\\\\-\", '-'}}.\n * \\return The codepoint and the new pointer. 
If the UTF-8 string is invalid, the function\n * returns (CharHandlingError::kInvalidUTF8, input char pointer); if the escape sequence is\n * invalid, it returns (CharHandlingError::kInvalidEscape, input char pointer).\n */\nstd::pair<TCodepoint, const char*> ParseNextUTF8OrEscaped(\n    const char* utf8,\n    const std::unordered_map<std::string, TCodepoint>& additional_escape_map = {});\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_ENCODING_H_\n"
  },
  {
    "path": "cpp/support/json_parser.h",
    "content": "/*!\n * \\file support/json_parser.h\n * \\brief Helps to parse JSON strings and objects.\n */\n#ifndef MLC_LLM_SUPPORT_JSON_PARSER_H_\n#define MLC_LLM_SUPPORT_JSON_PARSER_H_\n\n#include <tvm/ffi/container/shape.h>\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/data_type.h>\n#include <tvm/runtime/logging.h>\n\n#include <optional>\n\n#include \"result.h\"\n\nnamespace mlc {\nnamespace llm {\nnamespace json {\n\nusing ::tvm::ffi::json::Array;\nusing ::tvm::ffi::json::Object;\nusing ::tvm::ffi::json::Value;\n\n/*!\n * \\brief Parse a JSON string to a JSON object.\n * \\param json_str The JSON string to parse.\n * \\return The parsed JSON object.\n */\ninline Object ParseToJSONObject(const std::string& json_str) {\n  tvm::ffi::String err;\n  Value result = ::tvm::ffi::json::Parse(json_str, &err);\n  TVM_FFI_CHECK(err.empty(), ValueError)\n      << \"Failed to parse JSON: \" << std::string(err) << \". The JSON string is: \" << json_str;\n  auto opt = result.try_cast<Object>();\n  TVM_FFI_CHECK(opt.has_value(), ValueError)\n      << \"The given string is not a JSON object: \" << json_str;\n  return *opt;\n}\n\n/*!\n * \\brief Parse a JSON string to a JSON object.\n * \\param json_str The JSON string to parse.\n * \\return The parsed JSON object, or the error message.\n */\ninline Result<Object> ParseToJSONObjectWithResultReturn(const std::string& json_str) {\n  using TResult = Result<Object>;\n  tvm::ffi::String err;\n  Value result = ::tvm::ffi::json::Parse(json_str, &err);\n  if (!err.empty()) {\n    return TResult::Error(\"Failed to parse JSON. The JSON string is: \" + json_str +\n                          \". 
The error is \" + std::string(err));\n  }\n  auto opt = result.try_cast<Object>();\n  if (!opt.has_value()) {\n    return TResult::Error(\"ValueError: The given string is not a JSON object: \" + json_str);\n  }\n  return TResult::Ok(*opt);\n}\n\n/*!\n * \\brief Lookup a JSON object by a key, and convert it to a given type.\n * \\param json The JSON object to look up.\n * \\param key The key to look up.\n * \\tparam ValueType The type to be converted to.\n * \\return The converted value.\n */\ntemplate <typename ValueType>\nValueType Lookup(const Object& json, const std::string& key);\n/*!\n * \\brief Lookup a JSON array by an index, and convert it to a given type.\n * \\param json The JSON array to look up.\n * \\param index The index to look up.\n * \\tparam ValueType The type to be converted to.\n * \\return The converted value.\n */\ntemplate <typename ValueType>\nValueType Lookup(const Array& json, int index);\n/*!\n * \\brief Lookup a JSON object by a key, and convert it to a given type.\n * If the key doesn't exist or has null value, the default value is returned.\n * \\param json The JSON object to look up.\n * \\param key The key to look up.\n * \\tparam ValueType The type to be converted to.\n * \\return The converted value, or the default value if the key doesn't exist or has null value.\n */\ntemplate <typename ValueType>\ninline ValueType LookupOrDefault(const Object& json, const std::string& key,\n                                 const ValueType& default_value) {\n  if (json.count(key) == 0 || json.at(key) == nullptr) {\n    return default_value;\n  }\n  auto opt = json.at(key).try_cast<ValueType>();\n  TVM_FFI_CHECK(opt.has_value(), ValueError) << \"key `\" << key << \"` has unexpected type\";\n  return *opt;\n}\n/*!\n * \\brief Lookup a JSON object by a key, and convert it to a given type.\n * If the key doesn't exist or has null value, return std::nullopt.\n * \\param json The JSON object to look up.\n * \\param key The key to look up.\n * \\tparam 
ValueType The type to be converted to.\n * \\return The converted value, or std::nullopt if the value doesn't exist or has null value.\n */\ntemplate <typename ValueType>\ninline std::optional<ValueType> LookupOptional(const Object& json, const std::string& key) {\n  if (json.count(key) == 0 || json.at(key) == nullptr) {\n    return std::nullopt;\n  }\n  auto opt = json.at(key).try_cast<ValueType>();\n  TVM_FFI_CHECK(opt.has_value(), ValueError) << \"key `\" << key << \"` has unexpected type\";\n  return *opt;\n}\n/*!\n * \\brief Lookup a JSON object by a key, and convert it to a given type.\n * \\param json The JSON object to look up.\n * \\param key The key to look up.\n * \\tparam ValueType The type to be converted to.\n * \\return The converted value, or the error message.\n */\ntemplate <typename ValueType>\ninline Result<ValueType> LookupWithResultReturn(const Object& json, const std::string& key) {\n  using TResult = Result<ValueType>;\n  if (json.count(key) == 0) {\n    return TResult::Error(\"ValueError: key \\\"\" + key + \"\\\" not found in the JSON object\");\n  }\n  auto opt = json.at(key).try_cast<ValueType>();\n  if (!opt.has_value()) {\n    return TResult::Error(\"ValueError: key \\\"\" + key + \"\\\" has unexpected value type.\");\n  }\n  return TResult::Ok(*opt);\n}\n/*!\n * \\brief Lookup a JSON object by a key, and convert it to a given type.\n * If the key doesn't exist or has null value, the default value is returned.\n * \\param json The JSON object to look up.\n * \\param key The key to look up.\n * \\tparam ValueType The type to be converted to.\n * \\return The converted value, or the default value if the key doesn't exist or has null value\n * , or the error message.\n */\ntemplate <typename ValueType>\ninline Result<ValueType> LookupOrDefaultWithResultReturn(const Object& json, const std::string& key,\n                                                         const ValueType& default_value) {\n  using TResult = Result<ValueType>;\n  if 
(json.count(key) == 0 || json.at(key) == nullptr) {\n    return TResult::Ok(default_value);\n  }\n  auto opt = json.at(key).try_cast<ValueType>();\n  if (!opt.has_value()) {\n    return TResult::Error(\"ValueError: key \\\"\" + key + \"\\\" has unexpected value type.\");\n  }\n  return TResult::Ok(*opt);\n}\n/*!\n * \\brief Lookup a JSON object by a key, and convert it to a given type.\n * If the key doesn't exist or has null value, return std::nullopt.\n * \\param json The JSON object to look up.\n * \\param key The key to look up.\n * \\tparam ValueType The type to be converted to.\n * \\return The converted value, or std::nullopt if the value doesn't exist or has null value,\n * , or the error message.\n */\ntemplate <typename ValueType>\ninline Result<std::optional<ValueType>> LookupOptionalWithResultReturn(const Object& json,\n                                                                       const std::string& key) {\n  using TResult = Result<std::optional<ValueType>>;\n  if (json.count(key) == 0 || json.at(key) == nullptr) {\n    return TResult::Ok(std::nullopt);\n  }\n  auto opt = json.at(key).try_cast<ValueType>();\n  if (!opt.has_value()) {\n    return TResult::Error(\"ValueError: key \\\"\" + key + \"\\\" has unexpected value type.\");\n  }\n  return TResult::Ok(*opt);\n}\n\n// Implementation details\n\n/*! \\brief Shape extension to incorporate symbolic shapes. */\nstruct SymShapeTuple {\n  tvm::ffi::Shape shape_values;\n  std::vector<std::string> sym_names;\n\n  /*! \\brief Convert symbolic shape tuple to static shape tuple with model config. 
*/\n  tvm::ffi::Shape ToStatic(const Object& model_config) {\n    std::vector<int64_t> shape;\n    shape.reserve(shape_values.size());\n    for (int i = 0; i < static_cast<int>(shape_values.size()); ++i) {\n      if (shape_values[i] != -1) {\n        shape.push_back(shape_values[i]);\n      } else {\n        auto opt = model_config.at(sym_names[i]).try_cast<int64_t>();\n        TVM_FFI_CHECK(opt.has_value(), ValueError)\n            << \"model config is expected to contain \\\"\" << sym_names[i]\n            << \"\\\" as an integer. However, the given config has unexpected type for \\\"\"\n            << sym_names[i] << \"\\\".\";\n        shape.push_back(*opt);\n      }\n    }\n    return tvm::ffi::Shape(std::move(shape));\n  }\n};\n\nnamespace details {\n\ninline tvm::runtime::DataType DTypeFromString(const std::string& s) {\n  return tvm::runtime::DataType(tvm::runtime::StringToDLDataType(s));\n}\n\ninline SymShapeTuple SymShapeTupleFromArray(const Array& shape) {\n  std::vector<int64_t> result;\n  std::vector<std::string> sym_names;\n  result.reserve(shape.size());\n  sym_names.reserve(shape.size());\n  for (int i = 0; i < static_cast<int>(shape.size()); ++i) {\n    const auto& dim = shape[i];\n    auto str_opt = dim.try_cast<std::string>();\n    if (str_opt.has_value()) {\n      result.push_back(-1);\n      sym_names.push_back(*str_opt);\n    } else {\n      auto int_opt = dim.try_cast<int64_t>();\n      TVM_FFI_CHECK(int_opt.has_value(), ValueError) << \"shape has unexpected type\";\n      result.push_back(*int_opt);\n      sym_names.push_back(\"\");\n    }\n  }\n  return SymShapeTuple{tvm::ffi::Shape(std::move(result)), sym_names};\n}\n\n}  // namespace details\n\ntemplate <typename ValueType>\ninline ValueType Lookup(const Object& json, const std::string& key) {\n  TVM_FFI_CHECK(json.count(key) != 0, ValueError)\n      << \"key `\" << key << \"` not found in the JSON object\";\n  auto opt = json.at(key).try_cast<ValueType>();\n  
TVM_FFI_CHECK(opt.has_value(), ValueError) << \"key `\" << key << \"` has unexpected type\";\n  return *opt;\n}\n\ntemplate <typename ValueType>\ninline ValueType Lookup(const Array& json, int index) {\n  TVM_FFI_ICHECK(index >= 0 && index < static_cast<int>(json.size()))\n      << \"IndexError: json::array index out of range\";\n  auto opt = json[index].try_cast<ValueType>();\n  TVM_FFI_ICHECK(opt.has_value()) << \"ValueError: value at index `\" << index\n                                  << \"` has unexpected type\";\n  return *opt;\n}\n\ntemplate <>\ninline tvm::runtime::DataType Lookup(const Object& json, const std::string& key) {\n  return details::DTypeFromString(Lookup<std::string>(json, key));\n}\n\ntemplate <>\ninline tvm::runtime::DataType Lookup(const Array& json, int index) {\n  return details::DTypeFromString(Lookup<std::string>(json, index));\n}\n\ntemplate <>\ninline SymShapeTuple Lookup(const Object& json, const std::string& key) {\n  return details::SymShapeTupleFromArray(Lookup<Array>(json, key));\n}\n\ntemplate <>\ninline SymShapeTuple LookupOrDefault(const Object& json, const std::string& key,\n                                     const SymShapeTuple& default_value) {\n  if (json.count(key) == 0 || json.at(key) == nullptr) {\n    return default_value;\n  }\n  return details::SymShapeTupleFromArray(Lookup<Array>(json, key));\n}\n\ntemplate <>\ninline SymShapeTuple Lookup(const Array& json, int index) {\n  return details::SymShapeTupleFromArray(Lookup<Array>(json, index));\n}\n\n}  // namespace json\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_JSON_PARSER_H_\n"
  },
  {
    "path": "cpp/support/load_bytes_from_file.h",
    "content": "/*!\n * Copyright (c) 2023-2025 by Contributors\n * \\file support/load_bytes_from_file.h\n * \\brief Utility methods to load from files.\n */\n#ifndef MLC_LLM_SUPPORT_LOAD_BYTES_FROM_FILE_H_\n#define MLC_LLM_SUPPORT_LOAD_BYTES_FROM_FILE_H_\n\n#include <tvm/runtime/logging.h>\n\n#include <fstream>\n#include <string>\n\nnamespace mlc {\nnamespace llm {\n\ninline std::string LoadBytesFromFile(const std::string& path) {\n  std::ifstream fs(path, std::ios::in | std::ios::binary);\n  TVM_FFI_ICHECK(!fs.fail()) << \"Cannot open \" << path;\n  std::string data;\n  fs.seekg(0, std::ios::end);\n  size_t size = static_cast<size_t>(fs.tellg());\n  fs.seekg(0, std::ios::beg);\n  data.resize(size);\n  fs.read(data.data(), size);\n  return data;\n}\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_LOAD_BYTES_FROM_FILE_H_\n"
  },
  {
    "path": "cpp/support/progress_bar.h",
    "content": "/*!\n * Copyright (c) 2023-2025 by Contributors\n * \\file support/progress_bar.h\n * \\brief A simple progress bar in C++.\n */\n#ifndef MLC_LLM_SUPPORT_PROGRESS_BAR_H_\n#define MLC_LLM_SUPPORT_PROGRESS_BAR_H_\n\n#include <iostream>\n#include <string>\n\nnamespace mlc {\nnamespace llm {\n\nclass ProgressBar {\n public:\n  explicit ProgressBar(int total, int width = 100) : total(total), width(width), cur(0) {}\n\n  void Progress() {\n    if (cur < total) {\n      ++cur;\n    }\n    int bar_width = width - 2;  // Adjust for borders\n    int completed = static_cast<int>(static_cast<float>(cur) / total * bar_width);\n    int remaining = bar_width - completed;\n    std::cout << \"[\"                          //\n              << std::string(completed, '=')  //\n              << \">\"                          //\n              << std::string(remaining, ' ')  //\n              << \"] \"                         //\n              << \" [\" << cur << \"/\" << total << \"]\";\n    if (cur < total) {\n      std::cout << \"\\r\";\n      std::cout.flush();\n    } else {\n      std::cout << std::endl;  // Move to the next line after the progress bar is complete\n    }\n  }\n\n private:\n  int total;\n  int width;\n  int cur;\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_PROGRESS_BAR_H_\n"
  },
  {
    "path": "cpp/support/random.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file support/random.h\n * \\brief Header of random number generator.\n */\n\n#ifndef MLC_LLM_SUPPORT_RANDOM_H_\n#define MLC_LLM_SUPPORT_RANDOM_H_\n\n#include <random>\n\nnamespace mlc {\nnamespace llm {\n\n// Random number generator\nclass RandomGenerator {\n private:\n  std::mt19937 gen;\n  std::uniform_real_distribution<> dis;\n\n public:\n  RandomGenerator(int seed = std::random_device{}()) : gen(seed), dis(0.0, 1.0) {}\n\n  static RandomGenerator& GetInstance(int seed = std::random_device{}()) {\n    static RandomGenerator instance(seed);\n    return instance;\n  }\n\n  double GetRandomNumber() { return dis(gen); }\n\n  void SetSeed(int seed) { gen.seed(seed); }\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_RANDOM_H_\n"
  },
  {
    "path": "cpp/support/result.h",
    "content": "/*!\n * Copyright (c) 2023-2025 by Contributors\n * \\file support/result.h\n * \\brief The header for the Result class in MLC LLM.\n */\n#ifndef MLC_LLM_SUPPORT_RESULT_H_\n#define MLC_LLM_SUPPORT_RESULT_H_\n\n#include <tvm/runtime/logging.h>\n\n#include <optional>\n#include <string>\n#include <utility>\n\nnamespace mlc {\nnamespace llm {\n\n/*!\n * \\brief The result class in MLC LLM.\n * Each instance is either an okay value or an error.\n * \\tparam T The okay value type of the result.\n * \\tparam E The error type of the result.\n */\ntemplate <typename T, typename E = std::string>\nclass Result {\n public:\n  /*! \\brief Create a result with an okay value. */\n  static Result Ok(T value) {\n    Result result;\n    result.ok_value_ = std::move(value);\n    return result;\n  }\n  /*! \\brief Create a result with an error value. */\n  static Result Error(E error) {\n    Result result;\n    result.err_value_ = std::move(error);\n    return result;\n  }\n  /*! \\brief Check if the result is okay or not. */\n  bool IsOk() const { return ok_value_.has_value(); }\n  /*! \\brief Check if the result is an error or not. */\n  bool IsErr() const { return err_value_.has_value(); }\n  /*!\n   * \\brief Unwrap the result and return the okay value.\n   * Throws an exception if the result is an error.\n   * \\note This function returns the ok value by moving, so a Result can be unwrapped only once.\n   */\n  T Unwrap() {\n    TVM_FFI_ICHECK(ok_value_.has_value()) << \"Cannot unwrap result on an error value.\";\n    TVM_FFI_ICHECK(!unwrapped_) << \"Cannot unwrap a Result instance twice.\";\n    unwrapped_ = true;\n    return std::move(ok_value_.value());\n  }\n  /*!\n   * \\brief Unwrap the result and return the error value.\n   * Throws an exception if the result is an okay value.\n   * \\note This function returns the error value by moving, so a Result can be unwrapped only once.\n   */\n  E UnwrapErr() {\n    TVM_FFI_ICHECK(err_value_.has_value()) << \"Cannot unwrap result on an okay value.\";\n  
  TVM_FFI_ICHECK(!unwrapped_) << \"Cannot unwrap a Result instance twice.\";\n    unwrapped_ = true;\n    return std::move(err_value_.value());\n  }\n\n private:\n  /*! \\brief Whether the result has already been unwrapped. */\n  bool unwrapped_ = false;\n  /*! \\brief The internal optional okay value. */\n  std::optional<T> ok_value_;\n  /*! \\brief The internal optional error value. */\n  std::optional<E> err_value_;\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_RESULT_H_\n"
  },
  {
    "path": "cpp/support/utils.h",
    "content": "/*!\n * Copyright (c) 2023-2025 by Contributors\n * \\file support/utils.h\n * \\brief Utility functions.\n */\n#ifndef MLC_LLM_SUPPORT_UTILS_H_\n#define MLC_LLM_SUPPORT_UTILS_H_\n\n#include <tvm/support/io.h>\n\n#include <sstream>\n#include <string>\n#include <vector>\n\n#include \"../../3rdparty/tvm/src/support/base64.h\"\n#include \"../../3rdparty/tvm/src/support/bytes_io.h\"\n\nnamespace mlc {\nnamespace llm {\n\n/*! \\brief Split the input string by the given delimiter character. */\ninline std::vector<std::string> Split(const std::string& str, char delim) {\n  std::string item;\n  std::istringstream is(str);\n  std::vector<std::string> ret;\n  while (std::getline(is, item, delim)) {\n    ret.push_back(item);\n  }\n  return ret;\n}\n\n/*!\n * \\brief Check whether the string starts with a given prefix.\n * \\param str The given string.\n * \\param prefix The given prefix.\n * \\return Whether the prefix matched.\n */\ninline bool StartsWith(const std::string& str, const char* prefix) {\n  size_t n = str.length();\n  for (size_t i = 0; i < n; i++) {\n    if (prefix[i] == '\\0') return true;\n    if (str.data()[i] != prefix[i]) return false;\n  }\n  // return true if the str is equal to the prefix\n  return prefix[n] == '\\0';\n}\n\n/*!\n * \\brief Get the base64 encoded result of a string.\n * \\param str The string to encode.\n * \\return The base64 encoded string.\n */\ninline std::string Base64Encode(std::string str) {\n  std::string result;\n  tvm::support::BytesOutStream m_stream(&result);\n  tvm::support::Base64OutStream b64stream(&m_stream);\n  static_cast<tvm::support::Stream*>(&b64stream)->Write(str);\n  b64stream.Finish();\n  return result;\n}\n\n/*!\n * \\brief Get the base64 decoded result of a string.\n * \\param str The string to decode.\n * \\return The base64 decoded string.\n */\ninline std::string Base64Decode(std::string str) {\n  std::string result;\n  tvm::support::BytesInStream m_stream(str);\n  tvm::support::Base64InStream 
b64stream(&m_stream);\n  b64stream.InitPosition();\n  static_cast<tvm::support::Stream*>(&b64stream)->Read(&result);\n  return result;\n}\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_UTILS_H_\n"
  },
  {
    "path": "cpp/support/vlm_utils.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file support/vlm_utils.cc\n */\n#include \"vlm_utils.h\"\n\n#include <cmath>\n\nnamespace mlc {\nnamespace llm {\n\nvoid CalculateResizeShape(tvm::runtime::Tensor image_data, std::string model_type,\n                          int* p_target_height, int* p_target_width) {\n  TVM_FFI_ICHECK_EQ(image_data->shape[3], 3) << \"Image format must be NHWC\";\n  int height = image_data->shape[1];\n  int width = image_data->shape[2];\n\n  if (\"phi3_v\" == model_type) {\n    const int hd_num = 4;\n    double ratio = static_cast<double>(width) / height;\n    int scale = 1;\n    while (scale * std::ceil(scale / ratio) <= hd_num) {\n      scale += 1;\n    }\n    scale -= 1;\n    *p_target_width = static_cast<int>(scale * 336);\n    *p_target_height = static_cast<int>(*p_target_width / ratio);\n  }\n}\n\nvoid CalculatePadShape(tvm::runtime::Tensor image_data, std::string model_type, int* p_pad_height,\n                       int* p_pad_width) {\n  TVM_FFI_ICHECK_EQ(image_data->shape[3], 3) << \"Image format must be NHWC\";\n  if (\"phi3_v\" == model_type) {\n    int resized_height = 0, resized_width = 0;\n    CalculateResizeShape(image_data, model_type, &resized_height, &resized_width);\n    int tar = static_cast<int>(std::ceil(resized_height / 336.0) * 336);\n    int top_padding = (tar - resized_height) / 2;\n    int bottom_padding = tar - resized_height - top_padding;\n    TVM_FFI_ICHECK_EQ(tar, resized_height + top_padding + bottom_padding)\n        << \"Padding size not equal!\";\n    *p_pad_height = tar;\n    *p_pad_width = resized_width;\n  }\n}\n\nvoid CalculateCropShape(tvm::runtime::Tensor image_data, std::string model_type, int* p_crop_height,\n                        int* p_crop_width) {\n  TVM_FFI_ICHECK_EQ(image_data->shape[3], 3) << \"Image format must be NHWC\";\n  if (\"phi3_v\" == model_type) {\n    int pad_h = 0, pad_w = 0;\n    CalculatePadShape(image_data, model_type, &pad_h, &pad_w);\n    *p_crop_height = 
pad_h / 336;\n    *p_crop_width = pad_w / 336;\n  }\n}\n\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/support/vlm_utils.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file support/vlm_utils.h\n * \\brief Image preprocessing utilities for vision-language models.\n */\n#ifndef MLC_LLM_SUPPORT_VLM_UTILS_H_\n#define MLC_LLM_SUPPORT_VLM_UTILS_H_\n\n#include <tvm/runtime/tensor.h>\n\n#include <string>\n\nnamespace mlc {\nnamespace llm {\n\n/*!\n * \\brief Calculate the target height and width for resizing an image based on the input data and\n * model type.\n * Only `phi3_v` is currently handled; for other model types the outputs are left unchanged.\n * \\param image_data The input image data as a TVM Tensor in NHWC layout.\n * \\param model_type The type of the model influencing the resizing parameters (e.g., phi3_v).\n * \\param p_target_height Pointer that receives the calculated target height.\n * \\param p_target_width Pointer that receives the calculated target width.\n */\nvoid CalculateResizeShape(tvm::runtime::Tensor image_data, std::string model_type,\n                          int* p_target_height, int* p_target_width);\n/*!\n * \\brief Calculate the padded height and width for an image based on the input data and model\n * type.\n * Only `phi3_v` is currently handled; for other model types the outputs are left unchanged.\n * \\param image_data The input image data as a TVM Tensor in NHWC layout.\n * \\param model_type The type of the model influencing the padding parameters (e.g., phi3_v).\n * \\param p_pad_height Pointer that receives the image height after padding.\n * \\param p_pad_width Pointer that receives the image width after padding.\n */\nvoid CalculatePadShape(tvm::runtime::Tensor image_data, std::string model_type, int* p_pad_height,\n                       int* p_pad_width);\n\n/*!\n * \\brief Calculate the number of crops along the height and width of an image based on the input\n * data and model type.\n * Only `phi3_v` is currently handled; for other model types the outputs are left unchanged.\n * \\param image_data The input 
image data as a TVM Tensor in NHWC layout.\n * \\param model_type The type of the model influencing the cropping parameters (e.g., phi3_v).\n * \\param p_crop_height Pointer that receives the number of crops along the height.\n * \\param p_crop_width Pointer that receives the number of crops along the width.\n */\nvoid CalculateCropShape(tvm::runtime::Tensor image_data, std::string model_type, int* p_crop_height,\n                        int* p_crop_width);\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_SUPPORT_VLM_UTILS_H_\n"
  },
  {
    "path": "cpp/tokenizers/streamer.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file streamer.cc\n */\n\n#include \"streamer.h\"\n\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/int_tuple.h>\n\n#include <algorithm>\n#include <limits>\n#include <string>\n#include <vector>\n\n#include \"tokenizers.h\"\n\nnamespace mlc {\nnamespace llm {\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  TextStreamerObj::RegisterReflection();\n  StopStrHandlerObj::RegisterReflection();\n}\n\n/****************** TextStreamer ******************/\n\nTextStreamerObj::TextStreamerObj(Tokenizer tokenizer) : tokenizer_(std::move(tokenizer)) {}\n\nTextStreamer::TextStreamer(Tokenizer tokenizer) {\n  data_ = tvm::ffi::make_object<TextStreamerObj>(std::move(tokenizer));\n}\n\nstd::string TextStreamerObj::Put(const std::vector<int32_t>& delta_tokens) {\n  TVM_FFI_ICHECK(!finished_) << \"`put` is not expected to be invoked after finish.\";\n  if (delta_tokens.empty()) {\n    return \"\";\n  }\n\n  std::string ret;\n  // We process delta tokens one by one.\n  for (int32_t delta_token : delta_tokens) {\n    // Push to pending tokens.\n    pending_tokens_.push_back(delta_token);\n\n    // all_tokens = prefix_tokens_ + pending_tokens_\n    std::vector<int32_t> all_tokens;\n    all_tokens.reserve(prefix_tokens_.size() + pending_tokens_.size());\n    all_tokens.insert(all_tokens.end(), prefix_tokens_.begin(), prefix_tokens_.end());\n    all_tokens.insert(all_tokens.end(), pending_tokens_.begin(), pending_tokens_.end());\n\n    // Decode prefix_tokens_ and all_tokens.\n    std::string prefix_str = prefix_tokens_.empty() ? 
\"\" : tokenizer_->Decode(prefix_tokens_);\n    std::string full_str = tokenizer_->Decode(all_tokens);\n\n    std::string validated_str;\n    std::vector<int32_t> new_pending_tokens;\n    if (full_str.compare(0, prefix_str.length(), prefix_str) == 0) {\n      // Case 1. prefix_str is a prefix of `full_str`.\n      // validated_str = full_str[len(prefix_str):]\n      validated_str = full_str.substr(prefix_str.length());\n      // Pop UTF-8 replacement characters from the back of pending tokens.\n      // - The UTF-8 replacement character takes 3 bytes.\n      // - A valid UTF-8 character has at most 4 bytes.\n      //   So there will be at most 3 tokens popped.\n      while (!pending_tokens_.empty() &&                         //\n             static_cast<int>(new_pending_tokens.size()) < 3 &&  //\n             validated_str.length() >= 3 &&                      //\n             validated_str.compare(validated_str.length() - 3, /*n=*/3, kReplacementCharacter) ==\n                 0) {\n        new_pending_tokens.push_back(pending_tokens_.back());\n        pending_tokens_.pop_back();\n        all_tokens.pop_back();\n        validated_str = tokenizer_->Decode(all_tokens).substr(prefix_str.length());\n      }\n    } else {\n      // Case 2. 
prefix_str is not a prefix of `full_str`.\n      // Pop pending tokens from the back.\n      // - Pop until prefix_str is indeed a prefix of full_str.\n      // - A valid UTF-8 character has at most 4 bytes.\n      //   So there will be at most 3 tokens popped.\n      // - If there are fewer than 3 pending tokens, skip popping.\n      //   This is because it is impossible to make full_str contain\n      //   prefix_str without popping all the pending tokens.\n      if (static_cast<int>(pending_tokens_.size()) < 3) {\n        continue;\n      }\n      bool get_valid_full_str = false;\n      while (!pending_tokens_.empty() && static_cast<int>(new_pending_tokens.size()) < 3) {\n        new_pending_tokens.push_back(pending_tokens_.back());\n        pending_tokens_.pop_back();\n        all_tokens.pop_back();\n        full_str = tokenizer_->Decode(all_tokens);\n        if (full_str.compare(0, prefix_str.length(), prefix_str) == 0) {\n          get_valid_full_str = true;\n          break;\n        }\n      }\n\n      if (get_valid_full_str) {\n        // We found a full_str that starts with prefix_str.\n        // So we return the full string with the prefix sliced off.\n        validated_str = full_str.substr(prefix_str.length());\n      } else {\n        // We cannot find a full_str that starts with prefix_str by\n        // popping up to 3 tokens.\n        // In this case, the remaining pending tokens are invalid UTF-8\n        // characters already, so we return the decoded pending tokens.\n        validated_str = tokenizer_->Decode(pending_tokens_);\n      }\n    }\n\n    if (!pending_tokens_.empty()) {\n      // Set the new prefix.\n      prefix_tokens_ = pending_tokens_;\n    }\n    std::reverse(new_pending_tokens.begin(), new_pending_tokens.end());\n    pending_tokens_ = new_pending_tokens;\n    ret += validated_str;\n  }\n  return ret;\n}\n\nstd::string TextStreamerObj::Finish() {\n  // all_tokens = prefix_tokens_ + pending_tokens_\n  std::vector<int32_t> all_tokens;\n  
all_tokens.reserve(prefix_tokens_.size() + pending_tokens_.size());\n  all_tokens.insert(all_tokens.end(), prefix_tokens_.begin(), prefix_tokens_.end());\n  all_tokens.insert(all_tokens.end(), pending_tokens_.begin(), pending_tokens_.end());\n\n  // Decode prefix_tokens_ and all_tokens.\n  std::string prefix_str = prefix_tokens_.empty() ? \"\" : tokenizer_->Decode(prefix_tokens_);\n  std::string full_str = all_tokens.empty() ? \"\" : tokenizer_->Decode(all_tokens);\n\n  finished_ = true;\n  if (full_str.compare(0, prefix_str.length(), prefix_str) == 0) {\n    // Case 1. prefix_str is a prefix of `full_str`.\n    return full_str.substr(prefix_str.length());\n  } else {\n    // Case 2. 
prefix_str is not a prefix of `full_str`.\n    // In this case, the remaining pending tokens are invalid UTF-8\n    // characters already, so we return the decoded pending tokens.\n    return tokenizer_->Decode(pending_tokens_);\n  }\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.tokenizers.TextStreamer\",\n           [](Tokenizer tokenizer) { return TextStreamer(std::move(tokenizer)); })\n      .def(\"mlc.tokenizers.TextStreamerPut\",\n           [](TextStreamer text_streamer, const IntTuple& delta_tokens) {\n             return text_streamer->Put(\n                 {delta_tokens->data, delta_tokens->data + delta_tokens->size});\n           })\n      .def_method(\"mlc.tokenizers.TextStreamerFinish\", &TextStreamerObj::Finish);\n}\n\n/****************** StopStrHandler ******************/\n\n/*! \\brief Create the KMP partial match table for the input string. */\ninline std::vector<int> CreatePartialMatchTable(const String& str) {\n  int length = str.length();\n  std::vector<int> partial_match_table = {-1};\n  partial_match_table.reserve(length);\n  for (int i = 1; i < length; ++i) {\n    int ptr = partial_match_table[i - 1];\n    while (ptr != -1 && str.at(ptr) != str.at(i - 1)) {\n      ptr = partial_match_table[ptr];\n    }\n    partial_match_table.push_back(ptr + 1);\n  }\n  return partial_match_table;\n}\n\nStopStrHandlerObj::StopStrHandlerObj(Array<String> stop_strs,\n                                     const std::vector<std::string>& token_table)\n    : stop_strs_(std::move(stop_strs)), token_table_(token_table) {\n  int num_stop_strs = stop_strs_.size();\n  cur_match_lengths_.resize(num_stop_strs, 0);\n\n  // Create the KMP partial match table for each stop string.\n  partial_match_tables_.reserve(num_stop_strs);\n  for (const String& stop_str : stop_strs_) {\n    TVM_FFI_ICHECK(!stop_str.empty()) << \"Stop string cannot be empty.\";\n    
partial_match_tables_.push_back(CreatePartialMatchTable(stop_str));\n  }\n}\n\nvoid StopStrHandlerObj::Put(int32_t token_id, std::vector<int64_t>* return_token_ids) {\n  TVM_FFI_ICHECK_NOTNULL(return_token_ids);\n\n  // Return the input token id if there is no stop string.\n  if (stop_strs_.empty()) {\n    return_token_ids->push_back(token_id);\n    return;\n  }\n\n  TVM_FFI_ICHECK(!stop_triggered_) << \"Cannot put new token when already stopped.\";\n\n  TVM_FFI_ICHECK_LT(token_id, static_cast<int>(token_table_.size()));\n  const std::string& token = token_table_[token_id];\n  pending_token_ids_.push_back(token_id);\n  pending_token_lengths_.push_back(token.length());\n\n  for (char ch : token) {\n    // The earliest starting position of a matched stop string.\n    int stop_starting_pos = std::numeric_limits<int>::max();\n    // The cutoff length that can be safely returned.\n    int cutoff_length = std::numeric_limits<int>::max();\n    // The maximum matched length.\n    int max_match_length = 0;\n\n    for (int str_id = 0; str_id < static_cast<int>(stop_strs_.size()); ++str_id) {\n      // - Run one step of KMP algorithm.\n      const std::vector<int>& partial_match_table = partial_match_tables_[str_id];\n      int& cur_match_length = cur_match_lengths_[str_id];\n      while (cur_match_length != -1 && ch != stop_strs_[str_id].at(cur_match_length)) {\n        cur_match_length = partial_match_table[cur_match_length];\n      }\n      ++cur_match_length;\n\n      // Case 1. The stop string is matched.\n      if (cur_match_length == stop_strs_[str_id].length()) {\n        stop_triggered_ = true;\n        stop_starting_pos =\n            std::min(stop_starting_pos,\n                     pending_string_len_ + 1 - static_cast<int>(stop_strs_[str_id].length()));\n        continue;\n      }\n\n      // Case 2. 
The stop string is not matched.\n      // - Get the cutoff length that can be safely returned.\n      TVM_FFI_ICHECK_GE(pending_string_len_ + 1, cur_match_length);\n      cutoff_length = std::min(cutoff_length, pending_string_len_ + 1 - cur_match_length);\n      // - Track the maximum match length, used to update the pending string length.\n      max_match_length = std::max(max_match_length, cur_match_length);\n    }\n\n    // Collect the token ids that can be safely cut off and returned.\n    if (stop_triggered_) {\n      cutoff_length = stop_starting_pos;\n    }\n    TVM_FFI_ICHECK_NE(cutoff_length, std::numeric_limits<int>::max());\n    TVM_FFI_ICHECK_GE(cutoff_length, 0);\n    int cum_length = 0;\n    while (!pending_token_ids_.empty() &&\n           cum_length + pending_token_lengths_.front() <= cutoff_length) {\n      cum_length += pending_token_lengths_.front();\n      return_token_ids->push_back(pending_token_ids_.front());\n      pending_token_ids_.erase(pending_token_ids_.begin());\n      pending_token_lengths_.erase(pending_token_lengths_.begin());\n    }\n    if (stop_triggered_) {\n      return;\n    }\n\n    TVM_FFI_ICHECK_LE(cum_length, cutoff_length);\n    // `cum_length` is the prefix length that we actually cut off.\n    pending_string_len_ = (cutoff_length - cum_length) + max_match_length;\n  }\n}\n\nStopStrHandler::StopStrHandler(Array<String> stop_strs,\n                               const std::vector<std::string>& token_table) {\n  data_ = tvm::ffi::make_object<StopStrHandlerObj>(std::move(stop_strs), token_table);\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.tokenizers.StopStrHandler\",\n           [](Array<String> stop_strs, const Tokenizer& tokenizer) {\n             return StopStrHandler(std::move(stop_strs), tokenizer->PostProcessedTokenTable());\n           })\n      .def(\"mlc.tokenizers.StopStrHandlerPut\",\n           [](StopStrHandler handler, int token_id) {\n             std::vector<int64_t> 
delta_tokens;\n             handler->Put(token_id, &delta_tokens);\n             return IntTuple(std::move(delta_tokens));\n           })\n      .def(\"mlc.tokenizers.StopStringHandlerFinish\",\n           [](StopStrHandler handler) {\n             std::vector<int64_t> remaining_token_ids;\n             handler->Finish(&remaining_token_ids);\n             return IntTuple(std::move(remaining_token_ids));\n           })\n      .def_method(\"mlc.tokenizers.StopStrHandlerStopTriggered\", &StopStrHandlerObj::StopTriggered);\n}\n\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/tokenizers/streamer.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file streamer.h\n * \\brief Header of streamers in MLC LLM.\n */\n\n#ifndef MLC_LLM_STREAMER_H_\n#define MLC_LLM_STREAMER_H_\n\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/object.h>\n\n#include \"tokenizers.h\"\n\nnamespace mlc {\nnamespace llm {\n\nusing namespace tvm::runtime;\n\n/****************** TextStreamer ******************/\n\n/*!\n * \\brief The class that streams back validated UTF-8 text strings\n * generated by the tokenizer.\n */\nclass TextStreamerObj : public Object {\n public:\n  explicit TextStreamerObj(Tokenizer tokenizer);\n\n  /*!\n   * \\brief Put new delta tokens into the streamer, and get the UTF-8-valid\n   * delta string. The text streamer may hold some of the input delta tokens\n   * which cannot yet be decoded into valid UTF-8 strings. The returned string\n   * is always guaranteed to be UTF-8 valid.\n   * \\param delta_tokens The new tokens to put into the streamer.\n   * \\return The decoded delta string after putting the new input tokens.\n   */\n  std::string Put(const std::vector<int32_t>& delta_tokens);\n\n  /*! \\brief Return the string decoded from the remaining tokens. 
*/\n  std::string Finish();\n\n  // REPLACEMENT CHARACTER (U+FFFD) in UTF-8.\n  static constexpr const char* kReplacementCharacter = \"\\xef\\xbf\\xbd\";\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<TextStreamerObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.TextStreamer\", TextStreamerObj, Object);\n\n private:\n  Tokenizer tokenizer_;\n  std::vector<int32_t> prefix_tokens_;\n  std::vector<int32_t> pending_tokens_;\n  bool finished_ = false;\n};\n\n/*!\n * \\brief Managed reference to TextStreamerObj\n * \\sa TextStreamerObj\n */\nclass TextStreamer : public ObjectRef {\n public:\n  /*! \\brief Construct a text streamer with a tokenizer. */\n  explicit TextStreamer(Tokenizer tokenizer);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TextStreamer, ObjectRef, TextStreamerObj);\n};\n\n/****************** StopStrHandler ******************/\n\n/*!\n * \\brief The stop string handler in MLC LLM, which takes input delta tokens\n * one at a time and returns the output delta tokens before stopping due to\n * stop strings.\n */\nclass StopStrHandlerObj : public Object {\n public:\n  explicit StopStrHandlerObj(Array<String> stop_strs, const std::vector<std::string>& token_table);\n\n  /*!\n   * \\brief Add a new input delta token to the handler, and push the output\n   * delta tokens before stopping into the given vector.\n   * The stop string handler may hold back some input tokens\n   * which may be part of a stop string.\n   * The returned tokens are always guaranteed not to be part of a stop string.\n   */\n  void Put(int32_t token_id, std::vector<int64_t>* return_token_ids);\n\n  /*!\n   * \\brief Finish stop string handling and append the remaining\n   * cached token ids to the given vector.\n   */\n  void 
Finish(std::vector<int64_t>* return_token_ids) const {\n    return_token_ids->insert(return_token_ids->end(), pending_token_ids_.begin(),\n                             pending_token_ids_.end());\n  }\n\n  /*! \\brief Check if the generation has stopped due to a stop string. */\n  bool StopTriggered() const { return stop_triggered_; }\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<StopStrHandlerObj>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.StopStrHandler\", StopStrHandlerObj, Object);\n\n private:\n  /*! \\brief The stop strings. */\n  Array<String> stop_strs_;\n  /*! \\brief The partial match table for each stop string in the KMP algorithm. */\n  std::vector<std::vector<int>> partial_match_tables_;\n  /*! \\brief The tokenizer token table for token id lookup. It is stored by reference,\n   * so the table must outlive this handler. */\n  const std::vector<std::string>& token_table_;\n\n  /************ Global states across all stop strings. ************/\n\n  /*! \\brief The globally pending string length. */\n  int pending_string_len_ = 0;\n  /*! \\brief The globally pending token ids. */\n  std::vector<int32_t> pending_token_ids_;\n  /*! \\brief The token string length of each pending token id. */\n  std::vector<int> pending_token_lengths_;\n  /*! \\brief A boolean flag indicating if stop has been triggered. */\n  bool stop_triggered_ = false;\n\n  /************ Per-stop-string states. ************/\n\n  /*! \\brief The current match length of the pending string against each stop string. 
*/\n  std::vector<int> cur_match_lengths_;\n};\n\n/*!\n * \\brief Managed reference to StopStrHandlerObj\n * \\sa StopStrHandlerObj\n */\nclass StopStrHandler : public ObjectRef {\n public:\n  explicit StopStrHandler(Array<String> stop_strs, const std::vector<std::string>& token_table);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(StopStrHandler, ObjectRef, StopStrHandlerObj);\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_STREAMER_H_\n"
  },
  {
    "path": "cpp/tokenizers/tokenizers.cc",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file tokenizers.cc\n */\n\n#include \"tokenizers.h\"\n\n#include <tokenizers_cpp.h>\n#include <tvm/ffi/extra/json.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/runtime/int_tuple.h>\n#include <tvm/runtime/logging.h>\n\n#include <algorithm>\n#include <array>\n#include <filesystem>\n#include <fstream>\n#include <string>\n#include <string_view>\n\n#include \"./../support/encoding.h\"\n#include \"./../support/load_bytes_from_file.h\"\n\nnamespace mlc {\nnamespace llm {\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  TokenizerInfoNode::RegisterReflection();\n  TokenizerObj::RegisterReflection();\n}\n\n#ifndef COMPILE_MLC_WASM_RUNTIME\n\nString TokenizerInfoNode::AsJSONString() const {\n  tvm::ffi::json::Object obj;\n  obj.Set(\"token_postproc_method\", token_postproc_method);\n  obj.Set(\"prepend_space_in_encode\", prepend_space_in_encode);\n  obj.Set(\"strip_space_in_decode\", strip_space_in_decode);\n  return tvm::ffi::json::Stringify(obj);\n}\n\nTokenizerInfo TokenizerInfo::FromJSONString(String json_string) {\n  tvm::ffi::String err;\n  auto v = tvm::ffi::json::Parse(json_string, &err);\n  TVM_FFI_ICHECK(err.empty()) << \"Failed to parse JSON: \" << err;\n\n  TVM_FFI_ICHECK(v.try_cast<tvm::ffi::json::Object>().has_value()) << \"JSON must be an object.\";\n  const auto& obj = v.cast<tvm::ffi::json::Object>();\n\n  ObjectPtr<TokenizerInfoNode> n = tvm::ffi::make_object<TokenizerInfoNode>();\n  if (obj.count(\"token_postproc_method\")) {\n    TVM_FFI_ICHECK(obj.at(\"token_postproc_method\").try_cast<tvm::ffi::String>().has_value());\n    n->token_postproc_method = obj.at(\"token_postproc_method\").cast<tvm::ffi::String>();\n  }\n  if (obj.count(\"prepend_space_in_encode\")) {\n    TVM_FFI_ICHECK(obj.at(\"prepend_space_in_encode\").try_cast<bool>().has_value());\n    n->prepend_space_in_encode = obj.at(\"prepend_space_in_encode\").cast<bool>();\n  }\n  if 
(obj.count(\"strip_space_in_decode\")) {\n    TVM_FFI_ICHECK(obj.at(\"strip_space_in_decode\").try_cast<bool>().has_value());\n    n->strip_space_in_decode = obj.at(\"strip_space_in_decode\").cast<bool>();\n  }\n\n  return TokenizerInfo(n);\n}\n\nTokenizer::Tokenizer(std::unique_ptr<tokenizers::Tokenizer> tokenizer, TokenizerInfo info) {\n  ObjectPtr<TokenizerObj> n = tvm::ffi::make_object<TokenizerObj>();\n  n->tokenizer = std::move(tokenizer);\n  n->info_ = std::move(info);\n  data_ = std::move(n);\n}\n\nstd::vector<int32_t> TokenizerObj::Encode(const std::string& text) const {\n  return tokenizer->Encode(text);\n}\n\nstd::vector<int32_t> TokenizerObj::EncodeNoPrependSpace(const std::string& text) const {\n  // TODO(yixin): now this only supports tokenizers with tokenizer.json\n  // other tokenizers should be supported.\n  static const constexpr char* kPaddingPrefix = \"\\x01\";\n  if (!info_->prepend_space_in_encode) {\n    return tokenizer->Encode(text);\n  }\n\n  auto result = tokenizer->Encode(kPaddingPrefix + text);\n  // remove the first two tokens: \"▁\" and \"<0x01>\"\n  result.erase(result.begin(), result.begin() + 2);\n  return result;\n}\n\nstd::vector<std::vector<int32_t>> TokenizerObj::EncodeBatch(const Array<String>& texts) const {\n  std::vector<std::string> texts_vec;\n  for (const String& text : texts) {\n    texts_vec.push_back(text);\n  }\n  return tokenizer->EncodeBatch(texts_vec);\n}\n\nstd::string TokenizerObj::Decode(const std::vector<int32_t>& token_ids) const {\n  return tokenizer->Decode(token_ids);\n}\n\nconst DynamicBitset& TokenizerObj::GetPrefixTokenMask() {\n  if (prefix_token_mask_.Size() != 0) {\n    return prefix_token_mask_;\n  }\n\n  int vocab_size = GetVocabSize();\n  prefix_token_mask_ = DynamicBitset(vocab_size);\n\n  // Sort all tokens\n  const auto& token_table = PostProcessedTokenTable();\n  std::vector<std::pair<std::string, int>> sorted_tokens;\n  for (int32_t token_id = 0; token_id < vocab_size; ++token_id) {\n    
sorted_tokens.emplace_back(token_table[token_id], token_id);\n  }\n  std::sort(sorted_tokens.begin(), sorted_tokens.end());\n\n  // Mark every token that is a prefix of another token. After sorting, it suffices to\n  // compare each token with its lexicographic successor.\n  for (int idx = 0; idx < vocab_size - 1; ++idx) {\n    const std::string& cur_token = sorted_tokens[idx].first;\n    const std::string& nxt_token = sorted_tokens[idx + 1].first;\n    if (cur_token.length() <= nxt_token.length() &&\n        std::string_view(nxt_token).substr(0, cur_token.length()) == cur_token) {\n      prefix_token_mask_.Set(sorted_tokens[idx].second);\n    }\n  }\n\n  return prefix_token_mask_;\n}\n\nsize_t TokenizerObj::GetVocabSize() const { return tokenizer->GetVocabSize(); }\n\nstd::string TokenizerObj::IdToToken(int32_t token_id) const {\n  return tokenizer->IdToToken(token_id);\n}\n\nint32_t TokenizerObj::TokenToId(const std::string& token) const {\n  return tokenizer->TokenToId(token);\n}\n\nTokenizer Tokenizer::FromPath(const String& _path, std::optional<TokenizerInfo> info) {\n  // Only detect the tokenizer info when it is not provided, since detection reads tokenizer.json.\n  TokenizerInfo info_value = info.has_value() ? info.value() : DetectTokenizerInfo(_path);\n  std::filesystem::path path{std::string(_path)};\n  std::filesystem::path sentencepiece;\n  std::filesystem::path huggingface;\n  std::filesystem::path rwkvworld;\n  TVM_FFI_ICHECK(std::filesystem::exists(path)) << \"Cannot find tokenizer via path: \" << _path;\n  if (std::filesystem::is_directory(path)) {\n    sentencepiece = path / \"tokenizer.model\";\n    huggingface = path / \"tokenizer.json\";\n    rwkvworld = path / \"tokenizer_model\";\n  } else {\n    sentencepiece = path.parent_path() / \"tokenizer.model\";\n    huggingface = path.parent_path() / \"tokenizer.json\";\n    rwkvworld = path.parent_path() / \"tokenizer_model\";\n  }\n  if (std::filesystem::exists(huggingface)) {\n    // Check HuggingFace\n    return Tokenizer(tokenizers::Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string())),\n                     info_value);\n  }\n  if (std::filesystem::exists(sentencepiece)) {\n    // Check SentencePiece\n    
LOG(WARNING)\n        << \"Using `tokenizer.model` since we cannot locate `tokenizer.json`.\\n\"\n        << \"It is recommended to use `tokenizer.json` to ensure all token mappings are included, \"\n        << \"since currently, files like `added_tokens.json`, `tokenizer_config.json` are ignored.\\n\"\n        << \"Consider converting `tokenizer.model` to `tokenizer.json` by compiling the model \"\n        << \"with MLC again, or see if MLC's huggingface provides this file.\";\n    return Tokenizer(\n        tokenizers::Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string())),\n        info_value);\n  }\n  {\n    // Check ByteLevelBPE\n    std::filesystem::path merges_path = path / \"merges.txt\";\n    std::filesystem::path vocab_path = path / \"vocab.json\";\n    std::filesystem::path added_tokens_path = path / \"added_tokens.json\";\n    if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) &&\n        std::filesystem::exists(added_tokens_path)) {\n      std::string vocab = LoadBytesFromFile(vocab_path.string());\n      std::string merges = LoadBytesFromFile(merges_path.string());\n      std::string added_tokens = LoadBytesFromFile(added_tokens_path.string());\n      return Tokenizer(tokenizers::Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens),\n                       info_value);\n    }\n  }\n  if (std::filesystem::exists(rwkvworld)) {\n    // Check RWKV\n    return Tokenizer(tokenizers::Tokenizer::FromBlobRWKVWorld(rwkvworld.string()), info_value);\n  }\n  LOG(FATAL) << \"Cannot find any tokenizer under: \" << _path;\n}\n\nTokenizerInfo Tokenizer::DetectTokenizerInfo(const String& path_str) {\n  std::filesystem::path path{std::string(path_str)};\n  TVM_FFI_ICHECK(std::filesystem::exists(path)) << \"Cannot find tokenizer via path: \" << path_str;\n  if (!std::filesystem::is_directory(path)) {\n    path = path.parent_path();\n  }\n  path = path / \"tokenizer.json\";\n  if (!std::filesystem::exists(path)) 
{\n    LOG(WARNING) << \"Tokenizer info is not detected as tokenizer.json is not found. The default \"\n                 << \"tokenizer info will be used.\";\n    return TokenizerInfo(tvm::ffi::make_object<TokenizerInfoNode>());\n  }\n\n  std::string tokenizer_json = LoadBytesFromFile(path.string());\n  tvm::ffi::String err;\n  auto v = tvm::ffi::json::Parse(tokenizer_json, &err);\n  TVM_FFI_ICHECK(err.empty()) << \"Failed to parse JSON: \" << err;\n  TVM_FFI_ICHECK(v.try_cast<tvm::ffi::json::Object>().has_value()) << \"JSON must be an object.\";\n  const auto& obj = v.cast<tvm::ffi::json::Object>();\n\n  ObjectPtr<TokenizerInfoNode> n = tvm::ffi::make_object<TokenizerInfoNode>();\n\n  // Step 1. Detect token_postproc_method: byte_fallback or byte_level\n  // Detect {\"type\": \"ByteLevel\"} or {\"type\": \"ByteFallback\"} in \"decoder\" field of the tokenizer\n  if (!obj.count(\"decoder\") || !obj.at(\"decoder\").try_cast<tvm::ffi::json::Object>().has_value()) {\n    LOG(WARNING) << \"Decoder field is not found in tokenizer.json. 
Use ByteFallback as default.\";\n    n->token_postproc_method = \"byte_fallback\";\n  } else {\n    auto decoder_obj = obj.at(\"decoder\").cast<tvm::ffi::json::Object>();\n    TVM_FFI_ICHECK(decoder_obj.count(\"type\") &&\n                   decoder_obj.at(\"type\").try_cast<tvm::ffi::String>().has_value());\n    auto type = decoder_obj.at(\"type\").cast<tvm::ffi::String>();\n\n    auto f_detect_decoder_type = [](ObjectPtr<TokenizerInfoNode> n,\n                                    const tvm::ffi::json::Value& decoder_json) {\n      TVM_FFI_ICHECK(decoder_json.try_cast<tvm::ffi::json::Object>().has_value());\n      TVM_FFI_ICHECK(decoder_json.cast<tvm::ffi::json::Object>().count(\"type\") &&\n                     decoder_json.cast<tvm::ffi::json::Object>()\n                         .at(\"type\")\n                         .try_cast<tvm::ffi::String>()\n                         .has_value());\n      auto type = decoder_json.cast<tvm::ffi::json::Object>().at(\"type\").cast<tvm::ffi::String>();\n      if (type == \"ByteLevel\") {\n        n->token_postproc_method = \"byte_level\";\n        return true;\n      } else if (type == \"ByteFallback\") {\n        n->token_postproc_method = \"byte_fallback\";\n        return true;\n      }\n      return false;\n    };\n\n    bool found = false;\n\n    // For sequence, examine every decoder\n    if (type == \"Sequence\") {\n      TVM_FFI_ICHECK(decoder_obj.count(\"decoders\") &&\n                     decoder_obj.at(\"decoders\").try_cast<tvm::ffi::json::Array>().has_value());\n      for (const tvm::ffi::json::Value& decoder :\n           decoder_obj.at(\"decoders\").cast<tvm::ffi::json::Array>()) {\n        if (f_detect_decoder_type(n, decoder)) {\n          found = true;\n        }\n      }\n    } else {\n      if (f_detect_decoder_type(n, obj.at(\"decoder\"))) {\n        found = true;\n      }\n    }\n\n    if (!found) {\n      LOG(WARNING) << \"Neither ByteLevel nor ByteFallback decoder is detected in tokenizer.json. 
\"\n                   << \"Use ByteFallback as default.\";\n      n->token_postproc_method = \"byte_fallback\";\n    }\n  }\n\n  // Step 2. Detect prepend_space_in_encode\n  // Find {\"type\": \"Prepend\", \"prepend\": \"▁\"} in \"normalizer\" field of the tokenizer\n  if (obj.count(\"normalizer\") &&\n      obj.at(\"normalizer\").try_cast<tvm::ffi::json::Object>().has_value()) {\n    const tvm::ffi::json::Value& normalizer_json = obj.at(\"normalizer\");\n\n    auto f_handle_normalizer = [](ObjectPtr<TokenizerInfoNode> n,\n                                  const tvm::ffi::json::Value& normalizer_json) {\n      TVM_FFI_ICHECK(normalizer_json.try_cast<tvm::ffi::json::Object>().has_value());\n      auto obj = normalizer_json.cast<tvm::ffi::json::Object>();\n      TVM_FFI_ICHECK(obj.count(\"type\") && obj.at(\"type\").try_cast<tvm::ffi::String>().has_value());\n      if (obj.at(\"type\").cast<tvm::ffi::String>() == \"Prepend\" && obj.count(\"prepend\") &&\n          obj.at(\"prepend\").try_cast<tvm::ffi::String>().has_value() &&\n          obj.at(\"prepend\").cast<tvm::ffi::String>() == \"\\xe2\\x96\\x81\") {\n        n->prepend_space_in_encode = true;\n        return true;\n      }\n      return false;\n    };\n\n    auto type = normalizer_json.cast<tvm::ffi::json::Object>().at(\"type\").cast<tvm::ffi::String>();\n    if (type == \"Sequence\") {\n      TVM_FFI_ICHECK(normalizer_json.cast<tvm::ffi::json::Object>().count(\"normalizers\") &&\n                     normalizer_json.cast<tvm::ffi::json::Object>()\n                         .at(\"normalizers\")\n                         .try_cast<tvm::ffi::json::Array>()\n                         .has_value());\n      for (const tvm::ffi::json::Value& normalizer : normalizer_json.cast<tvm::ffi::json::Object>()\n                                                         .at(\"normalizers\")\n                                                         .cast<tvm::ffi::json::Array>()) {\n        if (f_handle_normalizer(n, normalizer)) 
{\n          break;\n        }\n      }\n    } else {\n      f_handle_normalizer(n, normalizer_json);\n    }\n  }\n\n  // Step 3. Detect strip_space_in_decode\n  // Find {\"type\": \"Strip\", \"content\": \" \", \"start\": 1, \"stop\": 0} in \"decoder\" field of the\n  // tokenizer\n  if (obj.count(\"decoder\") && obj.at(\"decoder\").try_cast<tvm::ffi::json::Object>().has_value()) {\n    const tvm::ffi::json::Value& decoders_json = obj.at(\"decoder\");\n\n    auto f_handle_decoder = [](ObjectPtr<TokenizerInfoNode> n,\n                               const tvm::ffi::json::Value& decoder_json) {\n      TVM_FFI_ICHECK(decoder_json.try_cast<tvm::ffi::json::Object>().has_value());\n      auto obj = decoder_json.cast<tvm::ffi::json::Object>();\n      TVM_FFI_ICHECK(obj.count(\"type\") && obj.at(\"type\").try_cast<tvm::ffi::String>().has_value());\n      if (obj.at(\"type\").cast<tvm::ffi::String>() == \"Strip\" && obj.count(\"content\") &&\n          obj.at(\"content\").try_cast<tvm::ffi::String>().has_value() &&\n          obj.at(\"content\").cast<tvm::ffi::String>() == \" \" && obj.count(\"start\") &&\n          obj.at(\"start\").try_cast<int64_t>().has_value() && obj.at(\"start\").cast<int64_t>() == 1 &&\n          obj.count(\"stop\") && obj.at(\"stop\").try_cast<int64_t>().has_value() &&\n          obj.at(\"stop\").cast<int64_t>() == 0) {\n        n->strip_space_in_decode = true;\n        return true;\n      }\n      return false;\n    };\n\n    auto type = decoders_json.cast<tvm::ffi::json::Object>().at(\"type\").cast<tvm::ffi::String>();\n    if (type == \"Sequence\") {\n      TVM_FFI_ICHECK(decoders_json.cast<tvm::ffi::json::Object>().count(\"decoders\") &&\n                     decoders_json.cast<tvm::ffi::json::Object>()\n                         .at(\"decoders\")\n                         .try_cast<tvm::ffi::json::Array>()\n                         .has_value());\n      for (const tvm::ffi::json::Value& decoder : decoders_json.cast<tvm::ffi::json::Object>()\n    
                                                  .at(\"decoders\")\n                                                      .cast<tvm::ffi::json::Array>()) {\n        if (f_handle_decoder(n, decoder)) {\n          break;\n        }\n      }\n    } else {\n      f_handle_decoder(n, decoders_json);\n    }\n  }\n\n  return TokenizerInfo(n);\n}\n#endif\n\n/*! \\brief ByteFallback decoder: transform tokens like <0x1B> to hex char byte 1B */\ninline std::string ByteFallbackDecoder(const std::string& token) {\n  if (token.length() == 6 && token.substr(0, 3) == \"<0x\" && token.back() == '>') {\n    int byte = 0;\n    for (int i = 0; i < 2; ++i) {\n      byte *= 16;\n      byte +=\n          token[3 + i] >= '0' && token[3 + i] <= '9' ? token[3 + i] - '0' : token[3 + i] - 'A' + 10;\n    }\n    TVM_FFI_ICHECK(byte >= 0 && byte < 256);\n    return std::string(/*n=*/1, static_cast<char>(byte));\n  }\n  return token;\n}\n\n/*! \\brief SpaceReplacer decoder: transform \"\\u2581\" back to space */\ninline std::string SpaceReplacerDecoder(const std::string& token) {\n  // \\u2581 is the unicode for \"lower one eighth block\"\n  // UTF8 encoding for \\u2581 is 0xE2 0x96 0x81\n  std::string result;\n  for (size_t i = 0; i < token.size(); ++i) {\n    if (i + 2 < token.size() && token[i] == char(0xE2) && token[i + 1] == char(0x96) &&\n        token[i + 2] == char(0x81)) {\n      result += ' ';\n      i += 2;\n    } else {\n      result += token[i];\n    }\n  }\n  return result;\n}\n\n/*! \\brief ByteLevel decoder: inverses the bytes-to-unicode transformation in the encoding\n * process as in\n * https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59\n */\ninline std::string ByteLevelDecoder(const std::string& token) {\n  // clang-format off\n  // The inverse map of bytes_to_unicode. 
-1 means there is no mapping to this unicode.\n  static const std::array<int, 324> char_to_byte_map = {\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,\n    46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68,\n    69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,\n    92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,\n    112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, -1, -1, -1, -1,\n    -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,\n    -1, -1, -1, -1, -1, -1, -1, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, -1,\n    174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,\n    192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209,\n    210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227,\n    228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245,\n    246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,\n    13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 127, 128,\n    129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146,\n    147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 173\n  };\n  // clang-format on\n\n  auto unicode_codepoints = ParseUTF8(token.c_str(), UTF8ErrorPolicy::kReturnInvalid);\n  if (unicode_codepoints.size() == 1 && unicode_codepoints[0] == kInvalidUTF8) {\n    return token;\n  }\n\n  std::string decoded;\n\n  for (auto unicode_codepoint : unicode_codepoints) {\n    TVM_FFI_ICHECK(unicode_codepoint >= 0);\n    if 
(unicode_codepoint >= static_cast<int>(char_to_byte_map.size()) ||\n        char_to_byte_map[unicode_codepoint] == -1) {\n      // If there is no mapping, return the original token\n      return token;\n    }\n    decoded += static_cast<char>(char_to_byte_map[unicode_codepoint]);\n  }\n  return decoded;\n}\n\n/*!\n * \\brief Post-process a raw token to the actual token with the given post-processing method.\n */\ninline std::string PostProcessToken(const std::string& token,\n                                    const std::string& token_postproc_method) {\n  if (token_postproc_method == \"byte_fallback\") {\n    return SpaceReplacerDecoder(ByteFallbackDecoder(token));\n  } else if (token_postproc_method == \"byte_level\") {\n    return ByteLevelDecoder(token);\n  } else {\n    LOG(FATAL) << \"Unknown post-processing method: \" << token_postproc_method;\n  }\n}\n\nstd::vector<std::string> Tokenizer::PostProcessTokenTable(\n    const std::vector<std::string>& token_table, const std::string& token_postproc_method) {\n  std::vector<std::string> post_processed_token_table;\n  post_processed_token_table.reserve(token_table.size());\n  for (const std::string& token : token_table) {\n    post_processed_token_table.push_back(PostProcessToken(token, token_postproc_method));\n  }\n  return post_processed_token_table;\n}\n\n#ifndef COMPILE_MLC_WASM_RUNTIME\nconst std::vector<std::string>& TokenizerObj::PostProcessedTokenTable() {\n  if (!post_processed_token_table_.empty()) {\n    return post_processed_token_table_;\n  }\n\n  std::vector<std::string> raw_token_table;\n  int vocab_size = tokenizer->GetVocabSize();\n  raw_token_table.reserve(vocab_size);\n  for (int32_t token_id = 0; token_id < vocab_size; ++token_id) {\n    raw_token_table.push_back(tokenizer->IdToToken(token_id));\n  }\n  post_processed_token_table_ =\n      Tokenizer::PostProcessTokenTable(raw_token_table, info_->token_postproc_method);\n  return post_processed_token_table_;\n}\n\nTVM_FFI_STATIC_INIT_BLOCK() 
{\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def(\"mlc.tokenizers.Tokenizer\", [](const String& path) { return Tokenizer::FromPath(path); })\n      .def(\"mlc.tokenizers.TokenizerEncode\",\n           [](const Tokenizer& tokenizer, const String& text) {\n             std::vector<int32_t> token_ids = tokenizer->Encode(text);\n             return IntTuple{token_ids.begin(), token_ids.end()};\n           })\n      .def(\"mlc.tokenizers.TokenizerEncodeBatch\",\n           [](const Tokenizer& tokenizer, const Array<String>& texts) {\n             std::vector<std::vector<int32_t>> results = tokenizer->EncodeBatch(texts);\n             Array<IntTuple> ret;\n             ret.reserve(results.size());\n             for (const auto& result : results) {\n               ret.push_back(IntTuple{result.begin(), result.end()});\n             }\n             return ret;\n           })\n      .def(\"mlc.tokenizers.TokenizerDecode\",\n           [](const Tokenizer& tokenizer, const IntTuple& token_ids) {\n             return tokenizer->Decode({token_ids->data, token_ids->data + token_ids->size});\n           })\n      .def(\"mlc.tokenizers.DetectTokenizerInfo\",\n           [](const String& path) { return Tokenizer::DetectTokenizerInfo(path)->AsJSONString(); });\n}\n\n#endif\n\nTVM_FFI_STATIC_INIT_BLOCK() {\n  namespace refl = tvm::ffi::reflection;\n  refl::GlobalDef()\n      .def_packed(\"mlc.tokenizers.PostProcessTokenTable\",\n                  [](tvm::ffi::PackedArgs args, tvm::ffi::Any* rv) {\n                    Array<String> token_table_arr = args[0].cast<Array<String>>();\n                    std::string token_postproc_method = args[args.size() - 1].cast<String>();\n                    std::vector<std::string> token_table;\n                    for (int i = 0; i < token_table_arr.size(); ++i) {\n                      token_table.push_back(token_table_arr[i]);\n                    }\n                    std::vector<std::string> processed_token_table 
=\n                        Tokenizer::PostProcessTokenTable(token_table, token_postproc_method);\n\n                    // Convert std::vector<std::string> to Array<String>\n                    Array<String> processed_token_table_tvm;\n                    for (int i = 0; i < processed_token_table.size(); ++i) {\n                      processed_token_table_tvm.push_back(processed_token_table[i]);\n                    }\n                    *rv = processed_token_table_tvm;\n                  })\n      .def(\"mlc.tokenizers.PostProcessToken\",\n           [](const String& token, const String& token_postproc_method) {\n             return PostProcessToken(token, token_postproc_method);\n           });\n}\n\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "cpp/tokenizers/tokenizers.h",
    "content": "/*!\n *  Copyright (c) 2023-2025 by Contributors\n * \\file tokenizers.h\n * \\brief Header of tokenizer-related functions.\n */\n\n#ifndef MLC_LLM_TOKENIZER_H_\n#define MLC_LLM_TOKENIZER_H_\n\n#include <tokenizers_cpp.h>\n#include <tvm/ffi/container/array.h>\n#include <tvm/ffi/reflection/registry.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/object.h>\n\n#include <memory>\n#include <optional>\n#include <string>\n#include <unordered_map>\n#include <vector>\n\n#include \"../base.h\"\n#include \"../support/dynamic_bitset.h\"\n\nnamespace mlc {\nnamespace llm {\n\nusing namespace tvm::runtime;\nusing tvm::ffi::Array;\nusing tvm::ffi::String;\n\n/*! \\brief Useful information of the tokenizer during generation. */\nclass TokenizerInfoNode : public Object {\n public:\n  /*! \\brief The method to post-process the tokens to their original strings.\n   * Possible values (each refers to a kind of tokenizer):\n   * - \"byte_fallback\": The same as the byte-fallback BPE tokenizer, including LLaMA-2,\n   *   Mixtral-7b, etc. E.g. \"▁of\" -> \" of\", \"<0x1B>\" -> \"\\x1B\".\n   *   This method:\n   *   1) Transforms tokens like <0x1B> to hex char byte 1B. (so-called byte-fallback)\n   *   2) Replaces \\\\u2581 \"▁\" with space.\n   * - \"byte_level\": The same as the byte-level BPE tokenizer, including LLaMA-3, GPT-2,\n   *   Phi-2, etc. E.g. \"Ġin\" -> \" in\", \"ě\" -> \"\\x1B\"\n   *   This method inverts the bytes-to-unicode transformation in the encoding process in\n   *   https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59\n   */\n  String token_postproc_method = \"byte_fallback\";\n  /*! \\brief Whether to prepend a space during encoding. */\n  bool prepend_space_in_encode = false;\n  /*! \\brief Whether to strip the first space during decoding. 
*/\n  bool strip_space_in_decode = false;\n\n  String AsJSONString() const;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<TokenizerInfoNode>();\n  }\n\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO(\"mlc.serve.TokenizerInfo\", TokenizerInfoNode, Object);\n};\n\nclass TokenizerInfo : public ObjectRef {\n public:\n  /*! \\brief Create a TokenizerInfo object from a JSON string, e.g. one produced by AsJSONString(). */\n  static TokenizerInfo FromJSONString(String json_string);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(TokenizerInfo, ObjectRef, TokenizerInfoNode);\n};\n\n/*! \\brief A wrapper object class for the tokenizer. */\nclass TokenizerObj : public Object {\n public:\n  /*! \\brief The underlying tokenizer. */\n  std::unique_ptr<tokenizers::Tokenizer> tokenizer;\n\n  /*! \\brief Encode text into ids. */\n  std::vector<int32_t> Encode(const std::string& text) const;\n\n  /*! \\brief Encode text into ids. Some tokenizers may prepend a space during encoding; this\n   * method guarantees that the space is not prepended. */\n  std::vector<int32_t> EncodeNoPrependSpace(const std::string& text) const;\n\n  /*! \\brief Encode texts into ids. */\n  std::vector<std::vector<int32_t>> EncodeBatch(const Array<String>& texts) const;\n\n  /*! \\brief Decode token ids into text. */\n  std::string Decode(const std::vector<int32_t>& token_ids) const;\n\n  /*! \\brief Return the post-processed token table of the tokenizer. Special tokens are included. */\n  const std::vector<std::string>& PostProcessedTokenTable();\n\n  /*! \\brief Get the prefix token mask as a bitset. Tokens that are a prefix of another token\n   * are set to true, and all others are set to false. */\n  const DynamicBitset& GetPrefixTokenMask();\n\n  /*!\n   * \\brief Returns the vocabulary size. 
Special tokens are considered. This may be smaller than the\n   * `vocab_size` in config.json (length of logits), see https://github.com/QwenLM/Qwen2/issues/147\n   * and https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/discussions/47.\n   */\n  size_t GetVocabSize() const;\n\n  /*!\n   * \\brief Convert the given id to its corresponding token if it exists. If not, return an\n   * empty string.\n   */\n  std::string IdToToken(int32_t token_id) const;\n\n  /*!\n   * \\brief Convert the given token to its corresponding id if it exists. If not, return -1.\n   */\n  int32_t TokenToId(const std::string& token) const;\n\n  static void RegisterReflection() {\n    namespace refl = tvm::ffi::reflection;\n    refl::ObjectDef<TokenizerObj>();\n  }\n\n  friend class Tokenizer;\n  static constexpr const bool _type_has_method_sequal_reduce = false;\n  static constexpr const bool _type_has_method_shash_reduce = false;\n  static constexpr const bool _type_mutable = true;\n  TVM_FFI_DECLARE_OBJECT_INFO_FINAL(\"mlc.Tokenizer\", TokenizerObj, Object);\n\n private:\n  /*! \\brief Useful information of the tokenizer during generation. */\n  TokenizerInfo info_;\n  /*! \\brief The cached token table. */\n  std::vector<std::string> post_processed_token_table_;\n  /*! \\brief The cached prefix token mask. */\n  DynamicBitset prefix_token_mask_;\n};\n\nclass Tokenizer : public ObjectRef {\n public:\n  /*!\n   * \\brief Create a tokenizer from a directory path on disk.\n   * \\param path The path to the tokenizer or the tokenizer directory.\n   * \\param info The tokenizer info. If not provided, the info will be detected automatically.\n   */\n  MLC_LLM_DLL static Tokenizer FromPath(const String& path,\n                                        std::optional<TokenizerInfo> info = std::nullopt);\n\n  /*! \\brief Detect the tokenizer info from the given path of the tokenizer. 
*/\n  MLC_LLM_DLL static TokenizerInfo DetectTokenizerInfo(const String& path);\n\n  /*!\n   * \\brief Post-process the tokens in the token table into their original strings.\n   * \\param token_table The raw token table.\n   * \\param token_postproc_method The post-processing method to use.\n   * \\returns The post-processed token table containing the original strings.\n   */\n  static std::vector<std::string> PostProcessTokenTable(const std::vector<std::string>& token_table,\n                                                        const std::string& token_postproc_method);\n\n  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tokenizer, ObjectRef, TokenizerObj);\n\n private:\n  explicit Tokenizer(std::unique_ptr<tokenizers::Tokenizer> tokenizer, TokenizerInfo info);\n};\n\n}  // namespace llm\n}  // namespace mlc\n\n#endif  // MLC_LLM_TOKENIZER_H_\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "_build/\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= python -m sphinx\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/README.md",
    "content": "# MLC-LLM Documentation\n\nThe documentation is built with [Sphinx](https://www.sphinx-doc.org/en/master/).\n\n## Dependencies\n\nFirst, install the dependencies by running the following command in this directory:\n\n```bash\npip3 install -r requirements.txt\n```\n\n## Build the Documentation\n\nThen build the documentation by running:\n\n```bash\nmake html\n```\n\n## View the Documentation\n\nRun the following command to start a simple HTTP server:\n\n```bash\ncd _build/html\npython3 -m http.server\n```\n\nThen you can view the documentation in your browser at `http://localhost:8000`. To use a different port, pass the port number as an argument to the python command above (e.g. `python3 -m http.server 8080`).\n"
  },
  {
    "path": "docs/community/faq.rst",
    "content": ".. _FAQ:\n\nFrequently Asked Questions\n==========================\n\nThis is a list of Frequently Asked Questions (FAQ) about MLC-LLM. Feel free to suggest new entries!\n\n... How can I customize the temperature and repetition penalty of models?\n   Please check our :ref:`configure-mlc-chat-json` tutorial.\n\n... What quantization algorithm does MLC-LLM use?\n   Please check our :doc:`/compilation/configure_quantization` tutorial.\n\n... Why do I encounter the error ``free(): invalid pointer, Aborted (core dumped)`` at the end of model compilation?\n   This happens if you compiled TVM from source without hiding LLVM symbols in the CMake configuration.\n   Please follow the :ref:`Building TVM from Source <tvm-build-from-source>` tutorial to compile TVM with LLVM symbols hidden, or use our prebuilt MLC-LLM :doc:`pip wheels <../install/mlc_llm>`.\n"
  },
  {
    "path": "docs/community/guideline.rst",
    "content": ".. _community_guide:\n\nCommunity Guideline\n===================\n\n.. contents::\n  :depth: 2\n  :local:\n\nWelcome to the MLC-LLM community! Just like you, all of us are in awe of the immense power of large language models.\nOur goal for MLC-LLM is to foster a project that is driven by an open-source community, working together to democratize\nthis technology and make it accessible across various devices. We are thrilled to have you as part of our\ncommunity and eagerly anticipate your valuable contributions.\n\n\n.. _community_discussion:\n\nParticipate in Community Discussions\n------------------------------------\n\nWe encourage open discussions. If you encounter a bug or have a feature request, please file an issue in MLC-LLM's\nGitHub `issue tracker <https://github.com/mlc-ai/mlc-llm/issues>`__. You are encouraged to tag the issue with labels\nsuch as \"bug,\" \"feature request,\" or \"iOS\" so that the relevant developers can quickly notice your concern.\n\nAdditionally, we have set up a `Discord server <https://discord.gg/9Xpy2HGBuD>`__ for online discussions.\nWhile we encourage participation in the Discord server, we also recommend creating a GitHub issue even if the\ntopic has been discussed there. This ensures that the discussion is archived and searchable for future reference.\n\nBefore submitting an issue, we kindly ask you to check our :doc:`/community/faq` to see if your question has already been answered.\n\n.. _contribute-to-mlc-llm:\n\nContribute to MLC-LLM\n---------------------\n\n.. _fork-and-create-pull-requests:\n\nFork and Create Pull Requests\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nReady to contribute to MLC-LLM? Awesome! 
We are excited to see you are ready to contribute your code.\nThe standard way to make changes to the MLC-LLM code base is by creating a `pull request <https://github.com/mlc-ai/mlc-llm/pulls>`__;\nwe will review your code and merge it into the code base when it is ready.\n\nThe first step to becoming a developer is to `fork <https://github.com/mlc-ai/mlc-llm/fork>`__ the repository to your own\nGitHub account. You will then see a repository under ``https://github.com/username/mlc-llm``, where ``username`` is your GitHub username.\n\nYou can clone your fork to your local machine and commit changes, or edit the contents of your fork directly on GitHub (if you are just fixing typos).\nOnce your update is complete, you can click the ``contribute`` button and open a pull request to the main repository.\n\n.. _contribute-new-models:\n\nContribute New Models to MLC-LLM\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\n* If you have compiled a model using our :doc:`/compilation/compile_models` tutorial for an existing model architecture, please upload your models to the internet (e.g., Hugging Face) by following the :ref:`distribute-compiled-models` tutorial.\n\n* If you have added a new model variant to MLC-LLM by following our :doc:`/compilation/define_new_models` tutorial,\n  please create a pull request to add your model architecture (currently model architectures are placed under the\n  `relax_models <https://github.com/mlc-ai/mlc-llm/tree/main/mlc_llm/relax_model>`__ folder).\n\n.. _coding-styles:\n\nCoding Styles\n^^^^^^^^^^^^^\n\nFor Python code, we generally follow the `PEP8 style guide <https://peps.python.org/pep-0008/>`__.\nPython docstrings follow the `NumPy style <https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_numpy.html>`__.\nTo make things easy, you can use `black <https://pypi.org/project/black/>`__ to automatically format your Python code.\n\n.. 
code:: bash\n\n      pip install black\n      black your_python_file.py\n\nFor C++ code, we generally follow the `Google C++ style guide <https://google.github.io/styleguide/cppguide.html>`__.\nC++ comments should be `Doxygen-compatible <http://www.doxygen.nl/manual/docblocks.html#cppblock>`__.\nFor convenience, you can use `clang-format <https://clang.llvm.org/docs/ClangFormat.html>`__ to automatically format your C++ code.\n\n.. code:: bash\n\n      clang-format -i your_cpp_file.cpp\n\n.. _general-development-process:\n\nGeneral Development Process\n---------------------------\n\nEveryone in the community is welcome to send patches, write documentation, and propose new directions for the project.\nThe key guideline here is to enable everyone in the community to get involved and participate in decision-making and development.\nWe encourage public discussion in different channels so that everyone in the community can participate\nand stay informed about developments.\n\nCode reviews are one of the key ways to ensure the quality of the code. High-quality code reviews prevent long-term technical debt\nand are crucial to the success of the project. A pull request needs to be reviewed before it gets merged.\nA committer with expertise in the corresponding area moderates the pull request and merges the code when\nit is ready. The committer may request reviews from multiple reviewers who are familiar with that area of the code.\nWe encourage contributors to request code reviews themselves and help review each other's code -- remember that everyone\nis volunteering their time to the community. A high-quality code review can take as much effort as the code\ncontribution itself, and you are likely to get your own code reviewed quickly if you do others the same favor.\n\nThe community should strive to reach a consensus on technical decisions through discussion. 
We expect committers to\nmoderate technical discussions diplomatically and to provide suggestions with clear technical reasoning when necessary.\n\n\n.. _roles-committers:\n\nCommitters\n^^^^^^^^^^\n\nCommitters are individuals who are granted write access to the project. A committer is usually responsible for\none or several areas of the code, where they oversee the code review process.\nContributions can take all forms, including code contributions and code reviews, documentation, education, and outreach.\nPull request reviews are assigned to committers who have recently contributed to the area the PR belongs to.\nCommitters are essential for a high-quality, healthy project. The community actively looks for new committers\namong contributors. Each existing committer can nominate new committers to MLC projects.\n\n.. _roles-contributors:\n\nContributors\n^^^^^^^^^^^^\nYou are welcome to contribute even if you are not ready to become a committer yet. Everyone who contributes to\nthe project (in the form of code, bug fixes, documentation, tutorials, etc.) is a contributor.\nWe maintain a `page <https://github.com/mlc-ai/mlc-llm/blob/main/CONTRIBUTORS.md>`__ to acknowledge contributors;\nplease let us know if you have contributed to the project and your name is not included in the list.\n"
  },
  {
    "path": "docs/compilation/compile_models.rst",
    "content": ".. _compile-model-libraries:\n\nCompile Model Libraries\n=======================\n\nTo run a model with MLC LLM on any platform, we need:\n\n1. **Model weights** converted to MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC <https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/tree/main>`__).\n2. A **model library** that contains the inference logic.\n\nThis page describes how to compile a model library with MLC LLM. Model compilation optimizes\nmodel inference for a given platform, allowing users to bring their own new model\narchitecture, use different quantization modes, and customize the overall model\noptimization flow.\n\nNotably, in many cases you do not need to call compile explicitly.\n\n- If you are using the Python API, you can skip specifying ``model_lib`` and\n  the system will JIT-compile the library.\n\n- If you are building an iOS/Android package, check out :ref:`package-libraries-and-weights`,\n  which provides a simpler high-level command that runs compilation behind the scenes.\n\n\nThis page is still helpful for understanding the compilation flow behind the scenes,\nor for explicitly creating model libraries.\nWe compile ``RedPajama-INCITE-Chat-3B-v1`` with ``q4f16_1`` as an example for all platforms.\n\n.. note::\n    Before you proceed, make sure you have followed :ref:`install-tvm`; TVM is a required\n    backend for compiling models with MLC LLM.\n\n    Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-engine` to obtain\n    the CLI app / Python API that can be used to chat with the compiled model.\n\n\n.. contents:: Table of Contents\n    :depth: 1\n    :local:\n\n0. Verify Installation\n----------------------\n\n**Step 1. Verify mlc_llm**\n\nWe use the Python package ``mlc_llm`` to compile models. This can be installed by\nfollowing :ref:`install-mlc-packages`, either by building from source or by\ninstalling the prebuilt package. 
Verify the ``mlc_llm`` installation from the command line:\n\n.. code:: bash\n\n    $ mlc_llm --help\n    # You should see help information with this line\n    usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config}\n\n.. note::\n    If you run into the error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``.\n\n**Step 2. Verify TVM**\n\nTo compile models, you also need to follow :ref:`install-tvm`.\nHere we quickly verify ``tvm`` from the command line (for full verification, see :ref:`tvm-validate`):\n\n.. code:: bash\n\n    $ python -c \"import tvm; print(tvm.__file__)\"\n    /some-path/lib/python3.13/site-packages/tvm/__init__.py\n\n1. Clone from HF and convert_weight\n-----------------------------------\n\nThis replicates :ref:`convert-weights-via-MLC`; see that page for more details.\n\nYou can run these commands from the mlc-llm repo or from your own working directory. Note that all platforms\ncan share the same compiled/quantized weights.\n\n.. code:: shell\n\n    # Create directory\n    mkdir -p dist/models && cd dist/models\n    # Clone HF weights\n    git lfs install\n    git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1\n    cd ../..\n    # Convert weights\n    mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n        --quantization q4f16_1 \\\n        -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n\n2. Generate mlc-chat-config and compile\n---------------------------------------\n\nA model library is specified by:\n\n - The model architecture (e.g. ``llama-2``, ``gpt-neox``)\n - Quantization (e.g. ``q4f16_1``, ``q0f32``)\n - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill_chunk_size``), which affects memory planning\n - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``)\n\nThe first three are recorded in ``mlc-chat-config.json``, which is generated by ``gen_config``; the platform is selected with ``--device`` when running ``compile``.\n\n.. 
code:: shell\n\n    # Create the output directory for the compiled model library\n    mkdir dist/libs\n\n.. 
tabs::\n\n    .. group-tab:: Linux - CUDA\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device cuda -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so\n\n\n    .. group-tab:: Metal\n\n        For M-chip Mac:\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so\n\n        Cross-Compiling for Intel Mac on M-chip Mac:\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. 
compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib\n\n        For Intel Mac:\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device metal -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib\n\n\n    .. group-tab:: Vulkan\n\n        For Linux:\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so\n\n        For Windows:\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. 
compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device vulkan -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.dll\n\n    .. group-tab:: iOS/iPadOS\n\n        You need a Mac to compile models for iOS/iPadOS.\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \\\n                --conv-template redpajama_chat --context-window-size 768 \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device iphone -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar\n\n        .. note::\n            If you run into the error\n\n            .. code:: text\n\n                Compilation error:\n                xcrun: error: unable to find utility \"metal\", not a developer tool or in PATH\n                xcrun: error: unable to find utility \"metallib\", not a developer tool or in PATH\n\n            make sure the Command Line Tools for Xcode are installed correctly.\n            You can validate this with ``xcrun metal``: if it prints ``metal: error: no input files``, the Command Line Tools for Xcode are installed and can be found, and you can proceed with model compilation.\n\n    .. group-tab:: Android\n\n        .. code:: shell\n\n            # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ --quantization q4f16_1 \\\n                --conv-template redpajama_chat --context-window-size 768 \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device android -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar\n\n    .. group-tab:: WebGPU\n\n        .. code:: shell\n\n            # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n            mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                --quantization q4f16_1 --conv-template redpajama_chat \\\n                -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n            # 2. compile: compile model library with specification in mlc-chat-config.json\n            mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm\n\n        .. note::\n            To compile for WebGPU, you need to install ``mlc_llm`` by building it from source. In addition, you need to follow :ref:`install-web-build`.\n            Otherwise, you will run into the error\n\n            .. code:: text\n\n                RuntimeError: Cannot find libraries: wasm_runtime.bc\n\n        .. note::\n            For WebGPU, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or use a lower ``--context-window-size`` to decrease memory usage.\n            Otherwise, you may run into issues like:\n\n            .. 
code:: text\n\n                TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from\n                'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range.\n\n.. note::\n\n    For the ``conv-template``, `conversation_template.py <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/conversation_template.py>`__\n    contains a full list of conversation templates that MLC provides. If the model you are adding\n    requires a new conversation template, you need to add your own.\n    Follow `this PR <https://github.com/mlc-ai/mlc-llm/pull/2163>`__ as an example.\n    However, adding your own template requires you to :ref:`build mlc_llm from source <mlcchat_build_from_source>`\n    in order for it to be recognized by the runtime.\n\n    For more details, please see :ref:`configure-mlc-chat-json`.\n\n3. Verify output and chat\n-------------------------\n\nBy executing the commands above, we generate the model weights, the model library, and a chat config.\nWe can check the output with the commands below:\n\n.. tabs::\n\n    .. group-tab:: Linux - CUDA\n\n        .. code:: shell\n\n            ~/mlc-llm > ls dist/libs\n              RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so      # ===> the model library\n\n            ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n              mlc-chat-config.json                             # ===> the chat config\n              tensor-cache.json                               # ===> the model weight info\n              params_shard_0.bin                               # ===> the model weights\n              params_shard_1.bin\n              ...\n              tokenizer.json                                   # ===> the tokenizer files\n              tokenizer_config.json\n\n        We can now chat with the model using the command line interface (CLI) app or the Python API.\n\n        .. 
code:: shell\n\n            python\n            >>> from mlc_llm import MLCEngine\n            >>> engine = MLCEngine(model=\"./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n            ...   model_lib=\"./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-cuda.so\")\n            >>> engine.chat.completions.create(\n            ...   messages=[{\"role\": \"user\", \"content\": \"hello\"}]\n            ... )\n            ChatCompletionResponse(\n              choices=[ChatCompletionResponseChoice(\n                message=ChatCompletionMessage(\n                  content=\"Hi! How can I assist you today?\", role='assistant'\n                )\n              )],\n              ...\n            )\n\n    .. group-tab:: Metal\n\n        .. code:: shell\n\n            ~/mlc-llm > ls dist/libs\n              RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so     # ===> the model library (will be -metal_x86_64.dylib for Intel Mac)\n\n            ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n              mlc-chat-config.json                             # ===> the chat config\n              tensor-cache.json                               # ===> the model weight info\n              params_shard_0.bin                               # ===> the model weights\n              params_shard_1.bin\n              ...\n              tokenizer.json                                   # ===> the tokenizer files\n              tokenizer_config.json\n\n        We can now chat with the model using the command line interface (CLI) app or the Python API.\n\n        .. code:: shell\n\n            python\n            >>> from mlc_llm import MLCEngine\n            >>> engine = MLCEngine(model=\"./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n            ...   model_lib=\"./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal.so\")\n            >>> engine.chat.completions.create(\n            ...   messages=[{\"role\": \"user\", \"content\": \"hello\"}]\n            ... 
)\n            ChatCompletionResponse(\n              choices=[ChatCompletionResponseChoice(\n                message=ChatCompletionMessage(\n                  content=\"Hi! How can I assist you today?\", role='assistant'\n                )\n              )],\n              ...\n            )\n\n\n    .. group-tab:: Vulkan\n\n        .. code:: shell\n\n            ~/mlc-llm > ls dist/libs\n              RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so    # ===> the model library (will be .dll for Windows)\n\n            ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n              mlc-chat-config.json                             # ===> the chat config\n              tensor-cache.json                               # ===> the model weight info\n              params_shard_0.bin                               # ===> the model weights\n              params_shard_1.bin\n              ...\n              tokenizer.json                                   # ===> the tokenizer files\n              tokenizer_config.json\n\n        We can now chat with the model using the command line interface (CLI) app or the Python API.\n\n        .. code:: shell\n\n            python\n            >>> from mlc_llm import MLCEngine\n            >>> engine = MLCEngine(model=\"./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n            ...   model_lib=\"./dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-vulkan.so\")\n            >>> engine.chat.completions.create(\n            ...   messages=[{\"role\": \"user\", \"content\": \"hello\"}]\n            ... )\n            ChatCompletionResponse(\n              choices=[ChatCompletionResponseChoice(\n                message=ChatCompletionMessage(\n                  content=\"Hi! How can I assist you today?\", role='assistant'\n                )\n              )],\n              ...\n            )\n\n    .. group-tab:: iOS/iPadOS\n\n        .. 
code:: shell\n\n            ~/mlc-llm > ls dist/libs\n              RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar   # ===> the model library\n\n            ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n              mlc-chat-config.json                             # ===> the chat config\n              tensor-cache.json                               # ===> the model weight info\n              params_shard_0.bin                               # ===> the model weights\n              params_shard_1.bin\n              ...\n              tokenizer.json                                   # ===> the tokenizer files\n              tokenizer_config.json\n\n        The model library ``dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar``\n        will be packaged as a static library into the iOS app. Check out :ref:`deploy-ios` for more details.\n\n    .. group-tab:: Android\n\n        .. code:: shell\n\n            ~/mlc-llm > ls dist/libs\n              RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar  # ===> the model library\n\n            ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n              mlc-chat-config.json                             # ===> the chat config\n              tensor-cache.json                               # ===> the model weight info\n              params_shard_0.bin                               # ===> the model weights\n              params_shard_1.bin\n              ...\n              tokenizer.json                                   # ===> the tokenizer files\n              tokenizer_config.json\n\n        The model library ``dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar``\n        will be packaged as a static library into the Android app. Check out :ref:`deploy-android` for more details.\n\n    .. group-tab:: WebGPU\n\n        .. 
code:: shell\n\n            ~/mlc-llm > ls dist/libs\n              RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm  # ===> the model library\n\n            ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n              mlc-chat-config.json                             # ===> the chat config\n              tensor-cache.json                               # ===> the model weight info\n              params_shard_0.bin                               # ===> the model weights\n              params_shard_1.bin\n              ...\n              tokenizer.json                                   # ===> the tokenizer files\n              tokenizer_config.json\n\n        To use this in the WebGPU runtime, check out :ref:`webllm-runtime`.\n\nCompile Commands for More Models\n--------------------------------\n\nThis section lists compile commands for more models that you can try out. These commands can be easily\ngeneralized to any model variant whose architecture MLC LLM supports.\n\n.. tabs::\n\n    .. tab:: Model: Llama-2-7B\n\n        Please first `request access <https://huggingface.co/meta-llama>`_ to the Llama-2 weights from Meta.\n        Once access is granted, create the directory ``dist/models`` and download the model into it.\n        For example, you can run the following commands:\n\n        .. code:: shell\n\n            mkdir -p dist/models && cd dist/models\n            git lfs install\n            git clone https://huggingface.co/meta-llama/Llama-2-7b-chat-hf\n            cd ../..\n\n        Then convert the HF weights into MLC-compatible weights. Note that all platforms\n        can share the same compiled/quantized weights.\n\n        .. code:: shell\n\n            mlc_llm convert_weight ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC\n\n        Afterwards, run the following commands to generate the MLC chat config and compile the model.\n\n        .. 
code:: shell\n\n            # Create output directory for the model library compiled\n            mkdir dist/libs\n\n        .. tabs::\n\n            .. tab:: Target: CUDA\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device cuda -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so\n\n            .. tab:: Metal\n\n                For M-chip Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal.so\n\n                Cross-Compiling for Intel Mac on M-chip Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n                        --quantization q4f16_1 --conv-template redpajama_chat \\\n                        -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n                    # 2. 
compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device metal:x86-64 -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-metal_x86_64.dylib\n\n                For Intel Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device metal -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-metal_x86_64.dylib\n\n            .. tab:: Vulkan\n\n                For Linux:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.so\n\n                For Windows:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. 
compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device vulkan -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-vulkan.dll\n\n            .. tab:: WebGPU\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --context-window-size 2048 --conv-template llama-2 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device webgpu -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-webgpu.wasm\n\n                .. note::\n                    To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`.\n                    Otherwise, it would run into error\n\n                    .. code:: text\n\n                        RuntimeError: Cannot find libraries: wasm_runtime.bc\n\n            .. tab:: iPhone/iPad\n\n                You need a Mac to compile models for it.\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. 
compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device iphone -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-iphone.tar\n\n            .. tab:: Android\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Llama-2-7b-chat-hf/ --quantization q4f16_1 \\\n                        --conv-template llama-2 --context-window-size 768 -o dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device android -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-android.tar\n\n    .. tab:: Mistral-7B-Instruct-v0.2\n\n        Note that Mistral uses sliding window attention (SWA). Thus, instead of specifying\n        ``context-window-size``, we specify ``sliding-window-size``.\n\n        First create directory ``dist/models`` and download the model to the directory.\n        For example, you can run the following code:\n\n        .. code:: shell\n\n            mkdir -p dist/models && cd dist/models\n            git lfs install\n            git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2\n            cd ../..\n\n        Then convert the HF weights into MLC-compatible weights. Note that all platforms\n        can share the same compiled/quantized weights.\n\n        .. code:: shell\n\n            mlc_llm convert_weight ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC\n\n        Afterwards, run the following command to generate mlc config and compile the model.\n\n        .. 
code:: shell\n\n            # Create output directory for the model library compiled\n            mkdir dist/libs\n\n        .. tabs::\n\n            .. tab:: Target: CUDA\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device cuda -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-cuda.so\n\n            .. tab:: Metal\n\n                For M-chip Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal.so\n\n\n                For Intel Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. 
compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device metal -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-metal_x86_64.dylib\n\n            .. tab:: Vulkan\n\n                For Linux:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.so\n\n                For Windows:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device vulkan -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-vulkan.dll\n\n            .. tab:: WebGPU\n\n                .. code:: shell\n\n                    # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --prefill-chunk-size 1024 --conv-template mistral_default \\\n                        -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device webgpu -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-webgpu.wasm\n\n                .. note::\n                    To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`.\n                    Otherwise, it would run into error\n\n                    .. code:: text\n\n                        RuntimeError: Cannot find libraries: wasm_runtime.bc\n\n                .. note::\n                    For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage.\n                    Otherwise, you may run into issues like:\n\n                    .. code:: text\n\n                        TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from\n                        'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range.\n\n            .. tab:: iPhone/iPad\n\n                You need a Mac to compile models for it.\n\n                .. code:: shell\n\n                    # 1. 
gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128  \\\n                        -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device iphone -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-iphone.tar\n\n            .. tab:: Android\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/Mistral-7B-Instruct-v0.2/ --quantization q4f16_1 \\\n                        --conv-template mistral_default --sliding-window-size 1024 --prefill-chunk-size 128 -o dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/Mistral-7B-Instruct-v0.2-q4f16_1-MLC/mlc-chat-config.json \\\n                        --device android -o dist/libs/Mistral-7B-Instruct-v0.2-q4f16_1-android.tar\n\n    .. tab:: Other models\n\n        First create directory ``dist/models`` and download the model to the directory.\n        For example, you can run the following code:\n\n        .. code:: shell\n\n            mkdir -p dist/models && cd dist/models\n            git lfs install\n            git clone https://huggingface.co/DISTRIBUTOR/HF_MODEL\n            cd ../..\n\n        Then convert the HF weights into MLC-compatible weights. Note that all platforms\n        can share the same compiled/quantized weights.\n\n        .. 
code:: shell\n\n            mlc_llm convert_weight ./dist/models/HF_MODEL/ --quantization q4f16_1 -o dist/OUTPUT-MLC\n\n        Afterwards, run the following command to generate mlc config and compile the model.\n\n        .. code:: shell\n\n            # Create output directory for the model library compiled\n            mkdir dist/libs\n\n        .. tabs::\n\n            .. tab:: Target: CUDA\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device cuda -o dist/libs/OUTPUT-cuda.so\n\n            .. tab:: Metal\n\n                For M-chip Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal.so\n\n\n                For Intel Mac:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device metal -o dist/libs/OUTPUT-metal_x86_64.dylib\n\n            .. 
tab:: Vulkan\n\n                For Linux:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.so\n\n                For Windows:\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device vulkan -o dist/libs/OUTPUT-vulkan.dll\n\n            .. tab:: WebGPU\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device webgpu -o dist/libs/OUTPUT-webgpu.wasm\n\n                .. note::\n                    To compile for webgpu, you need to build from source when installing ``mlc_llm``. Besides, you also need to follow :ref:`install-web-build`.\n                    Otherwise, it would run into error\n\n                    .. code:: text\n\n                        RuntimeError: Cannot find libraries: wasm_runtime.bc\n\n                .. 
note::\n                    For webgpu, when compiling larger models like ``Llama-2-7B``, you may want to add ``--prefill-chunk-size 1024`` or lower ``--context-window-size`` to decrease memory usage.\n                    Otherwise, you may run into issues like:\n\n                    .. code:: text\n\n                        TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from\n                        'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range.\n\n            .. tab:: iPhone/iPad\n\n                You need a Mac to compile models for it.\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \\\n                        --context-window-size 768 -o dist/OUTPUT-MLC/\n                    # 2. compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device iphone -o dist/libs/OUTPUT-iphone.tar\n\n            .. tab:: Android\n\n                .. code:: shell\n\n                    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n                    mlc_llm gen_config ./dist/models/HF_MODEL/ --quantization q4f16_1 --conv-template CONV_TEMPLATE \\\n                        --context-window-size 768 -o dist/OUTPUT-MLC/\n                    # 2. 
compile: compile model library with specification in mlc-chat-config.json\n                    mlc_llm compile ./dist/OUTPUT-MLC/mlc-chat-config.json --device android -o dist/libs/OUTPUT-android.tar\n\nFor each model and each backend, the above only provides the most recommended build command (which is the most optimized).\nYou can also try with different argument values (e.g., different quantization modes, context window size, etc.),\nwhose build results affect runtime memory requirement, and it is possible that they may not run as\nfast and robustly as the provided one when running the model.\n\n.. note::\n    Uing 3-bit quantization usually can be overly aggressive and only works for limited settings.\n    If you encounter issues where the compiled model does not perform as expected,\n    consider utilizing a higher number of bits for quantization (e.g., 4-bit quantization).\n\nIf you are interested in distributing the model besides local execution, please checkout :ref:`distribute-compiled-models`.\n\n\n.. _compile-command-specification:\n\nCompile Command Specification\n-----------------------------\n\nAs you have seen in the section above, the model compilation is split into three steps: convert weights, generate\n``mlc-chat-config.json``, and compile the model. This section describes the list of options that can be used\nduring compilation.\n\n1. Convert Weight\n^^^^^^^^^^^^^^^^^\n\nWeight conversion command follows the pattern below:\n\n.. code:: text\n\n    mlc_llm convert_weight \\\n        CONFIG \\\n        --quantization QUANTIZATION_MODE \\\n        [--model-type MODEL_TYPE] \\\n        [--device DEVICE] \\\n        [--source SOURCE] \\\n        [--source-format SOURCE_FORMAT] \\\n        --output OUTPUT\n\nNote that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` are optional.\n\n--CONFIG                            It can be one of the following:\n\n                                    1. 
Path to a HuggingFace model directory that contains a ``config.json`` or\n                                    2. Path to ``config.json`` in HuggingFace format, or\n                                    3. The name of a pre-defined model architecture.\n\n                                    A ``config.json`` file in HuggingFace format defines the model architecture, including the vocabulary\n                                    size, the number of layers, the hidden size, number of attention heads, etc.\n                                    Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json.\n\n                                    A HuggingFace directory often contains a ``config.json`` which defines the model architecture,\n                                    the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations,\n                                    as well as an optional ``generation_config.json`` provides additional default configuration for\n                                    text generation.\n                                    Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main.\n\n                                    For existing pre-defined model architecture, see ``MODEL_PRESETS``\n                                    `here <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/compiler/model/model.py>`_.\n\n--quantization QUANTIZATION_MODE    The quantization mode we use to compile.\n\n                                    See :ref:`quantization_mode` for more information.\n                                    Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and\n                                    ``q4f16_awq``.\n\n                                    We encourage you to use 4-bit quantization, as the text generated by 3-bit\n                                    quantized models may have bad quality depending on the model.\n\n--model-type 
MODEL_TYPE             Model architecture such as \"llama\". If not set, it is inferred from ``config.json``.\n\n--device DEVICE                     The device used to do quantization such as \"cuda\" or \"cuda:0\". Will detect from\n                                    local available GPUs if not specified.\n\n--source SOURCE                     The path to original model weight, infer from ``config`` if missing.\n\n--source-format SOURCE_FORMAT       The format of source model weight, infer from ``config`` if missing.\n\n--output OUTPUT                     The output directory to save the quantized model weight.\n                                    Will create ``params_shard_*.bin`` and ```tensor-cache.json``` in this directory.\n\n2. Generate MLC Chat Config\n^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nIn order to compile a model, we first need to generate the ``mlc-chat-config.json``. This file contains specifications\nlike ``context-window-size`` and ``sliding-window-size``, among others that can alter the model compiled. We also process\ntokenizers in this step.\n\nConfig generation command follows the pattern below:\n\n.. code:: text\n\n    mlc_llm gen_config \\\n        CONFIG \\\n        --quantization QUANTIZATION_MODE \\\n        [--model-type MODEL_TYPE] \\\n        --conv-template CONV_TEMPLATE \\\n        [--context-window-size CONTEXT_WINDOW_SIZE] \\\n        [--sliding-window-size SLIDING_WINDOW_SIZE] \\\n        [--prefill-chunk-size PREFILL_CHUNK_SIZE] \\\n        [--tensor-parallel-shard TENSOR_PARALLEL_SHARDS] \\\n        --output OUTPUT\n\nNote that ``CONFIG`` is a positional argument. Arguments wrapped with ``[ ]`` are optional.\n\n--CONFIG                                        It can be one of the following:\n\n                                                1. Path to a HuggingFace model directory that contains a ``config.json`` or\n                                                2. 
Path to ``config.json`` in HuggingFace format, or\n                                                3. The name of a pre-defined model architecture.\n\n                                                A ``config.json`` file in HuggingFace format defines the model architecture, including the vocabulary\n                                                size, the number of layers, the hidden size, number of attention heads, etc.\n                                                Example: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json.\n\n                                                A HuggingFace directory often contains a ``config.json`` which defines the model architecture,\n                                                the non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations,\n                                                as well as an optional ``generation_config.json`` provides additional default configuration for\n                                                text generation.\n                                                Example: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main.\n\n                                                For existing pre-defined model architecture, see ``MODEL_PRESETS``\n                                                `here <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/compiler/model/model.py>`_.\n\n--quantization QUANTIZATION_MODE                The quantization mode we use to compile.\n\n                                                See :ref:`quantization_mode` for more information.\n                                                Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and\n                                                ``q4f16_awq``.\n\n                                                We encourage you to use 4-bit quantization, as the text generated by 3-bit\n                                             
   quantized models may have bad quality depending on the model.\n\n--model-type MODEL_TYPE                         Model architecture such as \"llama\". If not set, it is inferred from ``config.json``.\n\n--conv-template CONV_TEMPLATE                   Conversation template. It depends on how the model is tuned. Use \"LM\" for vanilla base model\n                                                For existing pre-defined templates, see ``CONV_TEMPLATES``\n                                                `here <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/model/model.py>`_.\n\n--context-window-size CONTEXT_WINDOW_SIZE       Option to provide the maximum sequence length supported by the model.\n                                                This is usually explicitly shown as context length or context window in the model card.\n                                                If this option is not set explicitly, by default,\n                                                it will be determined by ``context_window_size`` or ``max_position_embeddings`` in ``config.json``,\n                                                and the latter is usually inaccurate for some models.\n\n--sliding-window-size SLIDING_WINDOW            (Experimental) The sliding window size in sliding window attention (SWA).\n                                                This optional field overrides the ``sliding_window`` in ``config.json`` for\n                                                those models that use SWA. Currently only useful when compiling mistral-based models.\n                                                This flag subjects to future refactoring.\n\n--prefill-chunk-size PREFILL_CHUNK_SIZE         (Experimental) The chunk size during prefilling. 
By default,\n                                                the chunk size is the same as ``context_window_size`` or ``sliding_window_size``.\n                                                This flag subjects to future refactoring.\n\n--tensor-parallel-shard TENSOR_PARALLEL_SHARDS  Number of shards to split the model into in tensor parallelism multi-gpu inference.\n\n--output OUTPUT                                 The output directory for generated configurations, including `mlc-chat-config.json` and tokenizer configuration.\n\n3. Compile Model Library\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nAfter generating ``mlc-chat-config.json``, we can compile the model into a model library (files ending in ``.so``, ``.tar``, etc. that contains\nthe inference logic of a model).\n\nModel compilation command follows the pattern below:\n\n.. code:: text\n\n    mlc_llm compile \\\n        MODEL \\\n        [--quantization QUANTIZATION_MODE] \\\n        [--model-type MODEL_TYPE] \\\n        [--device DEVICE] \\\n        [--host HOST] \\\n        [--opt OPT] \\\n        [--system-lib-prefix SYSTEM_LIB_PREFIX] \\\n        --output OUTPUT \\\n        [--overrides OVERRIDES]\n\nNote that ``MODEL`` is a positional argument. Arguments wrapped with ``[ ]`` are optional.\n\n--MODEL                                     A path to ``mlc-chat-config.json``, or an MLC model directory that contains ``mlc-chat-config.json``.\n\n--quantization QUANTIZATION_MODE            The quantization mode we use to compile. 
If unprovided, will infer from ``MODEL``.\n\n                                            See :ref:`quantization_mode` for more information.\n                                            Available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and\n                                            ``q4f16_awq``.\n\n                                            We encourage you to use 4-bit quantization, as the text generated by 3-bit\n                                            quantized models may have bad quality depending on the model.\n\n--model-type MODEL_TYPE                     Model architecture such as \"llama\". If not set, it is inferred from ``mlc-chat-config.json``.\n\n--device DEVICE                             The GPU device to compile the model to. If not set, it is inferred from GPUs available locally.\n\n--host HOST                                 The host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS.\n                                            Examples of the LLVM triple:\n\n                                            1) iPhones: arm64-apple-ios;\n                                            2) ARM64 Android phones: aarch64-linux-android;\n                                            3) WebAssembly: wasm32-unknown-unknown-wasm;\n                                            4) Windows: x86_64-pc-windows-msvc;\n                                            5) ARM macOS: arm64-apple-darwin.\n\n--opt OPT                                   Optimization flags. 
MLC LLM maintains a predefined set of optimization flags,\n                                            denoted as ``O0``, ``O1``, ``O2``, and ``O3``, where ``O0`` means no optimization, ``O2``\n                                            enables most of them, and ``O3`` represents extreme optimization that could\n                                            potentially break the system.\n\n                                            Alternatively, individual optimization knobs can be specified explicitly, e.g.\n                                            ``--opt=\"cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0\"``.\n\n--system-lib-prefix SYSTEM_LIB_PREFIX       Adds a prefix to all exported symbols, similar to ``objcopy --prefix-symbols``.\n                                            This is useful when compiling multiple models into a single library to avoid symbol\n                                            conflicts. Unlike objcopy, this has no effect on shared libraries.\n\n--output OUTPUT                             The path to the output file. The suffix determines whether the output file is a shared library or\n                                            an archive of object files. Available suffixes:\n\n                                            1) Linux: .so (shared), .tar (objects);\n                                            2) macOS: .dylib (shared), .tar (objects);\n                                            3) Windows: .dll (shared), .tar (objects);\n                                            4) Android, iOS: .tar (objects);\n                                            5) Web: .wasm (WebAssembly).\n\n--overrides OVERRIDES                       Model configuration overrides applied on top of ``mlc-chat-config.json``. Supports\n                                            ``context_window_size``, ``prefill_chunk_size``, ``sliding_window``, ``max_batch_size``, and\n                                            ``tensor_parallel_shards``. 
Individual fields are specified as semicolon-separated ``key=value`` pairs, e.g.\n                                            ``--overrides \"context_window_size=1024;prefill_chunk_size=128\"``.\n"
  },
  {
    "path": "docs/compilation/configure_quantization.rst",
    "content": "Configure Quantization\n======================\n\nQuantization Algorithm\n----------------------\n\nThe default quantization algorithm used in MLC-LLM is grouping quantization method discussed in the papers `The case for 4-bit precision: k-bit Inference Scaling Laws <https://arxiv.org/abs/2212.09720>`__ and `LUT-GEMM: Quantized Matrix Multiplication based on LUTs for Efficient Inference in Large-Scale Generative Language Models <https://arxiv.org/abs/2206.09557>`__.\n\n.. _quantization_mode:\n\nQuantization Mode\n-----------------\n\nIn MLC-LLM we use a short code that indicates the quantization mode to use. MLC-LLM supports both\nweight-only quantization and weight-activation quantization.\n\nFor the weight-only quantization, he format of the code is ``qAfB(_id)``, where ``A`` represents the number\nof bits for storing weights and ``B`` represents the number of bits for storing activations.\nThe ``_id`` is an integer identifier to distinguish different quantization algorithms (e.g. symmetric, non-symmetric, AWQ, etc).\n\nCurrently, available options are: ``q0f16``, ``q0f32``, ``q3f16_1``, ``q4f16_1``, ``q4f32_1``, and ``q4f16_awq`` (not stable).\n\nFor the weight-activation quantization, currently MLC-LLM supports FP8 quantization on CUDA.\nThe available options are: ``e4m3_e4m3_f16`` and ``e5m2_e5m2_f16``. In these modes, both weights and activations are quantized to FP8 format.\nThe output of each layer is in higher precision (FP16) and then requantized to FP8.\n\n.. _calibration:\n\nCalibration\n-----------\n\nFor ``e4m3_e4m3_f16`` quantization, we need to calibrate the quantization parameters for the activations.\nThe calibration process is done by running the following command:\n\n1. Compile the calibration model\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWe use the same compilation workflow to compile the model in calibration mode.\nThe only difference is that we need to specify the quantization mode as ``e4m3_e4m3_f16_calibrate``.\n\n.. 
code-block:: bash\n\n    mlc_llm gen_config \\\n        <model-path> \\\n        --quantization e4m3_e4m3_f16_max_calibrate \\\n        --conv-template <conv-template> \\\n        --output <output-path>\n\n    mlc_llm convert_weight \\\n        <model-path> \\\n        --quantization e4m3_e4m3_f16_max_calibrate \\\n        --output <output-path>\n\n    mlc_llm compile \\\n        <config-path> \\\n        --output <output-path>\n\n2. Run the calibration model\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nWe run the calibration model on a dataset such as ShareGPT to collect statistics of the\nactivations. The calibration run updates the quantization parameters in the weights file\nin place. We turn off CUDA Graph, as it is not yet supported in the calibration process.\n\n.. code-block:: bash\n\n    mlc_llm calibrate \\\n        <model-path> \\\n        --model-lib <model-lib-path> \\\n        --dataset <dataset-path> \\\n        --num-calibration-samples <num-samples> \\\n        --opt \"cudagraph=0\" \\\n        --output <output-path>\n\n3. Compile the quantized model for inference\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nAfter the calibration process, we can compile the model for inference. In this step, we only need\nto generate the configuration file using the desired quantization format and compile the model.\nThe weights were already quantized and calibrated in the previous steps and do not need to be converted again.\n\n.. code-block:: bash\n\n    mlc_llm gen_config \\\n        <model-path> \\\n        --quantization e4m3_e4m3_f16 \\\n        --conv-template <conv-template> \\\n        --output <output-path>\n\n    mlc_llm compile \\\n        <config-path> \\\n        --output <output-path>\n"
  },
  {
    "path": "docs/compilation/convert_weights.rst",
    "content": ".. _convert-weights-via-MLC:\n\nConvert Model Weights\n=====================\n\nTo run a model with MLC LLM,\nwe need to convert model weights into MLC format (e.g. `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC <https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/tree/main>`_.)\nThis page walks us through the process of adding a model variant with ``mlc_llm convert_weight``, which\ntakes a huggingface model as input and converts/quantizes into MLC-compatible weights.\n\nSpecifically, we add RedPjama-INCITE-**Instruct**-3B-v1, while MLC already\nprovides a model library for RedPjama-INCITE-**Chat**-3B-v1, which we can reuse.\n\nThis can be extended to, e.g.:\n\n- Add ``OpenHermes-Mistral`` when MLC already supports Mistral\n- Add ``Llama-2-uncensored`` when MLC already supports Llama-2\n\n.. note::\n    Before you proceed, make sure you followed :ref:`install-tvm`, a required\n    backend to compile models with MLC LLM.\n\n    Please also follow the instructions in :ref:`deploy-cli` / :ref:`deploy-python-engine` to obtain\n    the CLI app / Python API that can be used to chat with the compiled model.\n\n\n.. contents:: Table of Contents\n    :depth: 1\n    :local:\n\n.. _verify_installation_for_compile:\n\n1. Verify installation\n----------------------\n\n**Step 1. Verify mlc_llm**\n\nWe use the python package ``mlc_llm`` to compile models. This can be installed by\nfollowing :ref:`install-mlc-packages`, either by building from source, or by\ninstalling the prebuilt package. Verify ``mlc_llm`` installation in command line via:\n\n.. code:: bash\n\n    $ mlc_llm --help\n    # You should see help information with this line\n    usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config}\n\n.. note::\n    If it runs into error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``.\n\n**Step 2. 
Verify TVM**\n\nTo compile models, you also need to follow :ref:`install-tvm`.\nHere we quickly verify ``tvm`` from the command line (for full verification, see :ref:`tvm-validate`):\n\n.. code:: bash\n\n    $ python -c \"import tvm; print(tvm.__file__)\"\n    /some-path/lib/python3.13/site-packages/tvm/__init__.py\n\n\n1. Clone from HF and convert_weight\n-----------------------------------\n\nYou can run the following commands under the mlc-llm repo or in your own working directory. Note that all platforms\ncan share the same converted/quantized weights. See :ref:`compile-command-specification`\nfor the specification of ``convert_weight``.\n\n.. code:: shell\n\n    # Create directory\n    mkdir -p dist/models && cd dist/models\n    # Clone HF weights\n    git lfs install\n    git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Instruct-3B-v1\n    cd ../..\n    # Convert weights\n    mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \\\n        --quantization q4f16_1 \\\n        -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC\n\n.. _generate_mlc_chat_config:\n\n2. Generate MLC Chat Config\n---------------------------\n\nUse ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers.\nSee :ref:`compile-command-specification` for the specification of ``gen_config``.\n\n.. code:: shell\n\n    mlc_llm gen_config ./dist/models/RedPajama-INCITE-Instruct-3B-v1/ \\\n        --quantization q4f16_1 --conv-template redpajama_chat \\\n        -o dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/\n\n\n.. note::\n    The file ``mlc-chat-config.json`` is crucial in both model compilation\n    and runtime chatting. 
Here we only care about the latter case.\n\n    You can **optionally** customize\n    ``dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/mlc-chat-config.json`` (check out :ref:`configure-mlc-chat-json` for more detailed instructions).\n    You can also simply use the default configuration.\n\n    The `conversation_template <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/conversation_template>`__\n    directory contains a full list of conversation templates that MLC provides. If the model you are adding\n    requires a new conversation template, you need to add your own.\n    Follow `this PR <https://github.com/mlc-ai/mlc-llm/pull/2163>`__ as an example. However,\n    adding your own template requires you to :ref:`build mlc_llm from source <mlcchat_build_from_source>` in order for it\n    to be recognized by the runtime.\n\nBy now, you should have the following files.\n\n.. code:: shell\n\n    ~/mlc-llm > ls dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC\n        mlc-chat-config.json                             # ===> the chat config\n        tensor-cache.json                                # ===> the model weight info\n        params_shard_0.bin                               # ===> the model weights\n        params_shard_1.bin\n        ...\n        tokenizer.json                                   # ===> the tokenizer files\n        tokenizer_config.json\n\n.. _distribute-compiled-models:\n\n(Optional) 3. Upload weights to HF\n----------------------------------\n\nOptionally, you can upload the converted weights to Hugging Face.\n\n.. code:: shell\n\n    # First, please create a repository on Hugging Face.\n    # With the repository created, run\n    git lfs install\n    git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo\n    cd my-redpajama3b-weight-huggingface-repo\n    cp path/to/mlc-llm/dist/RedPajama-INCITE-Instruct-3B-v1-q4f16_1-MLC/* .\n    git add . 
&& git commit -m \"Add redpajama-3b instruct model weights\"\n    git push origin main\n\nThis would result in something like `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n<https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/tree/main>`_, but\nfor **Instruct** instead of **Chat**.\n\nGood job, you have successfully distributed the model weights you converted.\nNext, we will talk about how we can consume the model weights in applications.\n\nDownload the Distributed Models\n-------------------------------\n\nYou can now use the existing MLC LLM tools, such as ``chat``, ``serve``, and ``package``, with the converted weights.\n\n.. code:: shell\n\n    mlc_llm chat HF://my-huggingface-account/my-redpajama3b-weight-huggingface-repo\n"
  },
  {
    "path": "docs/compilation/define_new_models.rst",
    "content": "Define New Model Architectures\n==============================\n\nThis page guides you how to add a new model architecture in MLC.\n\nThis notebook (runnable in Colab) should contain all necessary information to add a model in\nMLC LLM:\nhttps://github.com/mlc-ai/notebooks/blob/main/mlc-llm/tutorial_add_new_model_architecture_in_tvm_nn_module.ipynb\n\nIn the notebook, we leverage ``tvm.nn.module`` to define a model in MLC LLM. We also use ``JIT``\n(just-in-time compilation) to debug the implementation.\n\nYou can also refer to the PRs below on specific examples of adding a model architecture in MLC LLM:\n\n- `GPTNeoX PR <https://github.com/mlc-ai/mlc-llm/pull/1408>`_\n- `GPT-2 PR <https://github.com/mlc-ai/mlc-llm/pull/1314>`_\n- `Mistral PR <https://github.com/mlc-ai/mlc-llm/pull/1230>`_\n\n.. note::\n\n    When adding a model variant that has\n    its architecture already supported in mlc-llm , you **only need to convert weights**\n    (e.g. adding ``CodeLlama`` when MLC supports ``llama-2``; adding ``OpenHermes Mistral``\n    when MLC supports ``mistral``). On the other hand, a new model architecture\n    (or inference logic) requires more work (following the tutorial above).\n"
  },
  {
    "path": "docs/compilation/package_libraries_and_weights.rst",
    "content": ".. _package-libraries-and-weights:\n\nPackage Libraries and Weights\n=============================\n\nWhen we want to build LLM applications with MLC LLM (e.g., iOS/Android apps),\nusually we need to build static model libraries and app binding libraries,\nand sometimes bundle model weights into the app.\nMLC LLM provides a tool for fast model library and weight packaging: ``mlc_llm package``.\n\nThis page briefly introduces how to use ``mlc_llm package`` for packaging.\nTutorials :ref:`deploy-ios` and :ref:`deploy-android` contain detailed examples and instructions\non using this packaging tool for iOS and Android deployment.\n\n-----\n\nIntroduction\n------------\n\nTo use ``mlc_llm package``, we must clone the source code of `MLC LLM <https://github.com/mlc-ai/mlc-llm>`_\nand `install the MLC LLM and TVM package <https://llm.mlc.ai/docs/install/mlc_llm.html#option-1-prebuilt-package>`_.\nDepending on the app we build, there might be some other dependencies, which are described in\ncorresponding :ref:`iOS <deploy-ios>` and :ref:`Android <deploy-android>` tutorials.\n\nAfter cloning, the basic usage of ``mlc_llm package`` is as the following.\n\n.. code:: bash\n\n    export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm\n    cd /path/to/app  # The app root directory which contains \"mlc-package-config.json\".\n                     # E.g., \"ios/MLCChat\" or \"android/MLCChat\"\n    mlc_llm package\n\n**The package command reads from the JSON file** ``mlc-package-config.json`` **under the current directory.**\nThe output of this command is a directory ``dist/``,\nwhich contains the packaged model libraries (under ``dist/lib/``) and weights (under ``dist/bundle/``).\nThis directory contains all necessary data for the app build.\nDepending on the app we build, the internal structure of ``dist/lib/`` may be different.\n\n.. 
code::\n\n   dist\n   ├── lib\n   │   └── ...\n   └── bundle\n       └── ...\n\nThe input ``mlc-package-config.json`` file specifies\n\n* the device (e.g., iPhone or Android) to package model libraries and weights for,\n* the list of models to package.\n\nBelow is an example ``mlc-package-config.json`` file:\n\n.. code:: json\n\n    {\n        \"device\": \"iphone\",\n        \"model_list\": [\n            {\n                \"model\": \"HF://mlc-ai/Mistral-7B-Instruct-v0.2-q3f16_1-MLC\",\n                \"model_id\": \"Mistral-7B-Instruct-v0.2-q3f16_1\",\n                \"estimated_vram_bytes\": 3316000000,\n                \"bundle_weight\": true,\n                \"overrides\": {\n                    \"context_window_size\": 512\n                }\n            },\n            {\n                \"model\": \"HF://mlc-ai/gemma-2b-it-q4f16_1-MLC\",\n                \"model_id\": \"gemma-2b-q4f16_1\",\n                \"estimated_vram_bytes\": 3000000000,\n                \"overrides\": {\n                    \"prefill_chunk_size\": 128\n                }\n            }\n        ]\n    }\n\nThis example ``mlc-package-config.json`` specifies \"iphone\" as the target device.\nIn the ``model_list``,\n\n* ``model`` points to the Hugging Face repository which contains the pre-converted model weights. 
Apps will download model weights from the Hugging Face URL.\n* ``model_id`` is a unique model identifier.\n* ``estimated_vram_bytes`` is an estimate of the VRAM the model takes at runtime.\n* ``\"bundle_weight\": true`` means the model weights will be bundled into the app when building.\n* ``overrides`` specifies some model config parameter overrides.\n\n\nBelow is a more detailed specification of the ``mlc-package-config.json`` file.\nEach entry in ``\"model_list\"`` of the JSON file has the following fields:\n\n``model``\n   (Required) The path to the MLC-converted model to be built into the app.\n\n   Usually it is a Hugging Face URL (e.g., ``\"model\": \"HF://mlc-ai/phi-2-q4f16_1-MLC\"``) that contains the pre-converted model weights.\n   For iOS, it can also be a path to a local model directory which contains converted model weights (e.g., ``\"model\": \"../dist/gemma-2b-q4f16_1\"``).\n   Please check out :ref:`convert-weights-via-MLC` if you want to build a local model into the app.\n\n``model_id``\n   (Required) A unique local identifier for the model.\n   It can be an arbitrary string.\n\n``estimated_vram_bytes``\n   (Required) The estimated VRAM, in bytes, required to run the model.\n\n``bundle_weight``\n   (Optional) A boolean flag indicating whether to bundle model weights into the app.\n   If this field is set to true, the ``mlc_llm package`` command will copy the model weights\n   to ``dist/bundle/$model_id``.\n\n``overrides``\n   (Optional) A dictionary to override the default model context window size (to limit the KV cache size) and prefill chunk size (to limit the model's temporary execution memory).\n   Example:\n\n   .. 
code:: json\n\n      {\n         \"device\": \"iphone\",\n         \"model_list\": [\n            {\n               \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n               \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1\",\n               \"estimated_vram_bytes\": 2960000000,\n               \"overrides\": {\n                  \"context_window_size\": 512,\n                  \"prefill_chunk_size\": 128\n               }\n            }\n         ]\n      }\n\n``model_lib``\n   (Optional) A string specifying the system library prefix to use for the model.\n   This is usually used when you want to build multiple model variants with the same architecture into the app.\n   **This field does not affect any app functionality.**\n   The ``\"model_lib_path_for_prepare_libs\"`` introduced below is also related.\n   Example:\n\n   .. code:: json\n\n      {\n         \"device\": \"iphone\",\n         \"model_list\": [\n            {\n               \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n               \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1\",\n               \"estimated_vram_bytes\": 2960000000,\n               \"model_lib\": \"gpt_neox_q4f16_1\"\n            }\n         ]\n      }\n\n\nBesides ``model_list`` in ``MLCChat/mlc-package-config.json``,\nyou can also **optionally** specify a dictionary of ``\"model_lib_path_for_prepare_libs\"``\n**if you want to use model libraries that are manually compiled**.\nThe keys of this dictionary should be the ``model_lib`` values specified in the model list,\nand the values are the paths (absolute or relative) to the manually compiled model libraries.\nThe model libraries specified in ``\"model_lib_path_for_prepare_libs\"`` will be built into the app when running ``mlc_llm package``.\nExample:\n\n.. 
code:: json\n\n   {\n      \"device\": \"iphone\",\n      \"model_list\": [\n         {\n            \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n            \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1\",\n            \"estimated_vram_bytes\": 2960000000,\n            \"model_lib\": \"gpt_neox_q4f16_1\"\n         }\n      ],\n      \"model_lib_path_for_prepare_libs\": {\n         \"gpt_neox_q4f16_1\": \"../../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar\"\n      }\n   }\n\nCompilation Cache\n-----------------\n\n``mlc_llm package`` leverages a local JIT cache to avoid repetitive compilation of the same input.\nIt also uses a local cache for weights downloaded from remote sources. These caches\nare shared across the entire project. Sometimes it is helpful to force a rebuild when\nwe have a new compiler update or when something goes wrong with the cached library.\nYou can do so by setting the environment variable ``MLC_JIT_POLICY=REDO``:\n\n.. code:: bash\n\n   MLC_JIT_POLICY=REDO mlc_llm package\n\nArguments of ``mlc_llm package``\n--------------------------------\n\nThe ``mlc_llm package`` command can optionally take the arguments below:\n\n``--package-config``\n    A path to ``mlc-package-config.json``, which contains the device and model specification.\n    By default, it is the ``mlc-package-config.json`` under the current directory.\n\n``--mlc-llm-source-dir``\n    The path to the MLC LLM source code (cloned from https://github.com/mlc-ai/mlc-llm).\n    By default, it is the ``$MLC_LLM_SOURCE_DIR`` environment variable.\n    If neither ``$MLC_LLM_SOURCE_DIR`` nor ``--mlc-llm-source-dir`` is specified, an error is reported.\n\n``--output`` / ``-o``\n    The output directory of the ``mlc_llm package`` command.\n    By default, it is ``dist/`` under the current directory.\n\n\nSummary and What to Do Next\n---------------------------\n\nOn this page, we introduced the ``mlc_llm package`` command for fast model library and weight 
packaging.\n\n* It takes the input file ``mlc-package-config.json``, which contains the device and model specification for packaging.\n* It outputs the directory ``dist/``, which contains packaged libraries under ``dist/lib/`` and model weights under ``dist/bundle/``.\n\nNext, please feel free to check out the :ref:`iOS <deploy-ios>` and :ref:`Android <deploy-android>` tutorials for detailed examples of using ``mlc_llm package``.\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# -*- coding: utf-8 -*-\nimport os\nimport sys\n\nimport tlcpack_sphinx_addon\n\n# -- General configuration ------------------------------------------------\n\nsys.path.insert(0, os.path.abspath(\"../python\"))\nsys.path.insert(0, os.path.abspath(\"../\"))\nautodoc_mock_imports = [\"torch\"]\n\n# General information about the project.\nproject = \"mlc-llm\"\nauthor = \"MLC LLM Contributors\"\ncopyright = \"2023-2025, %s\" % author\n\n# Version information.\n\nversion = \"0.1.0\"\nrelease = \"0.1.0\"\n\nextensions = [\n    \"sphinx_tabs.tabs\",\n    \"sphinx_toolbox.collapse\",\n    \"sphinxcontrib.httpdomain\",\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"sphinx_reredirects\",\n]\n\nredirects = {\"get_started/try_out\": \"../index.html#getting-started\"}\n\nsource_suffix = [\".rst\"]\n\nlanguage = \"en\"\n\nexclude_patterns = [\"_build\", \"Thumbs.db\", \".DS_Store\"]\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = \"sphinx\"\n\n# A list of ignored prefixes for module index sorting.\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = False\n\n# -- Options for HTML output ----------------------------------------------\n\n# The theme is set by the make target\nimport sphinx_rtd_theme\n\nhtml_theme = \"sphinx_rtd_theme\"\nhtml_theme_path = [sphinx_rtd_theme.get_html_theme_path()]\n\ntemplates_path = []\n\nhtml_static_path = []\n\nfooter_copyright = \"© 2023-2025 MLC LLM\"\nfooter_note = \" \"\n\nhtml_logo = \"_static/img/mlc-logo-with-text-landscape.svg\"\n\nhtml_theme_options = {\n    \"logo_only\": True,\n}\n\nheader_links = [\n    (\"Home\", \"https://llm.mlc.ai/\"),\n    (\"Github\", \"https://github.com/mlc-ai/mlc-llm\"),\n    (\"Discord Server\", \"https://discord.gg/9Xpy2HGBuD\"),\n]\n\nheader_dropdown = {\n    \"name\": \"Other Resources\",\n    \"items\": [\n        (\"MLC Course\", \"https://mlc.ai/\"),\n        (\"MLC Blog\", 
\"https://blog.mlc.ai/\"),\n        (\"Web LLM\", \"https://webllm.mlc.ai/\"),\n    ],\n}\n\nhtml_context = {\n    \"footer_copyright\": footer_copyright,\n    \"footer_note\": footer_note,\n    \"header_links\": header_links,\n    \"header_dropdown\": header_dropdown,\n    \"display_github\": True,\n    \"github_user\": \"mlc-ai\",\n    \"github_repo\": \"mlc-llm\",\n    \"github_version\": \"main/docs/\",\n    \"theme_vcs_pageview_mode\": \"edit\",\n    # \"header_logo\": \"/path/to/logo\",\n    # \"header_logo_link\": \"\",\n    # \"version_selecter\": \"\",\n}\n\n\n# add additional overrides\ntemplates_path += [tlcpack_sphinx_addon.get_templates_path()]\nhtml_static_path += [tlcpack_sphinx_addon.get_static_path()]\n"
  },
  {
    "path": "docs/deploy/android.rst",
    "content": ".. _deploy-android:\n\nAndroid SDK\n===========\n\n.. contents:: Table of Contents\n   :local:\n   :depth: 2\n\nDemo App\n--------\n\nThe demo APK below is built for Samsung S23 with Snapdragon 8 Gen 2 chip.\n\n.. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png\n  :width: 135\n  :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android-09262024/mlc-chat.apk\n\nPrerequisite\n------------\n\n**Rust** (`install <https://www.rust-lang.org/tools/install>`__) is needed to cross-compile HuggingFace tokenizers to Android. Make sure rustc, cargo, and rustup are available in ``$PATH``.\n\n**Android Studio** (`install <https://developer.android.com/studio>`__) with NDK and CMake. To install NDK and CMake, on the Android Studio welcome page, click \"Projects → SDK Manager → SDK Tools\". If you have already installed NDK in your development environment, please update your NDK to avoid build android package fail(`#2696 <https://github.com/mlc-ai/mlc-llm/issues/2696>`__). The current demo Android APK is built with NDK 27.0.11718014. Once you have installed or updated the NDK, set up the following environment variables:\n\n\n- ``ANDROID_NDK`` so that ``$ANDROID_NDK/build/cmake/android.toolchain.cmake`` is available.\n- ``TVM_NDK_CC`` that points to NDK's clang compiler.\n\n.. 
code-block:: bash\n\n  # Example on macOS\n  export ANDROID_NDK=$HOME/Library/Android/sdk/ndk/27.0.11718014\n  export TVM_NDK_CC=$ANDROID_NDK/toolchains/llvm/prebuilt/darwin-x86_64/bin/aarch64-linux-android24-clang\n  # Example on Linux\n  export ANDROID_NDK=$HOME/Android/Sdk/ndk/27.0.11718014\n  export TVM_NDK_CC=$ANDROID_NDK/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android24-clang\n  # Example on Windows (cmd)\n  set ANDROID_NDK=%USERPROFILE%/AppData/Local/Android/Sdk/ndk/27.0.11718014\n  set TVM_NDK_CC=%ANDROID_NDK%/toolchains/llvm/prebuilt/windows-x86_64/bin/aarch64-linux-android24-clang\n\n**JDK**, such as OpenJDK >= 17, is needed to compile the Java bindings of the TVM runtime.\nWe strongly recommend setting ``JAVA_HOME`` to the JDK bundled with Android Studio,\ne.g., ``export JAVA_HOME=/Applications/Android\\ Studio.app/Contents/jbr/Contents/Home`` on macOS or\n``export JAVA_HOME=/opt/android-studio/jbr`` on Linux.\nUsing Android Studio's JBR bundle, as recommended `here <https://developer.android.com/build/jdks>`__,\nreduces the chance of errors in JNI compilation.\nSet up the following environment variable:\n\n- ``export JAVA_HOME=/path/to/java_home``. You can then double-check that ``$JAVA_HOME/bin/java`` exists.\n\nPlease ensure that Android Studio and ``JAVA_HOME`` use the same JDK version.\n\n**TVM runtime** is placed under `3rdparty/tvm <https://github.com/mlc-ai/mlc-llm/tree/main/3rdparty>`__ in MLC LLM, so there is no need to install anything extra. Set up the following environment variable:\n\n- ``export TVM_SOURCE_DIR=/path/to/mlc-llm/3rdparty/tvm``.\n\nPlease follow :doc:`/install/mlc_llm` to obtain a binary build of the mlc_llm package. Note that this\nis independent of the mlc-llm source code that we use for the Android package build in the following section.\nOnce you have installed this package, you do not need to build MLC LLM from source.\n\n.. 
note::\n    ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts.\n\nAs a final check, make sure the **environment variables** are properly set. One way to ensure this is to place them in ``$HOME/.zshrc``, ``$HOME/.bashrc``, or an environment management tool.\n\n.. code-block:: bash\n\n  source $HOME/.cargo/env    # Rust\n  export ANDROID_NDK=...     # Android NDK toolchain\n  export TVM_NDK_CC=...      # Android NDK clang\n  export JAVA_HOME=...       # Java\n  export TVM_SOURCE_DIR=...  # TVM runtime\n\nAdditional Guides for Windows Users\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nBuilding for Android under Windows is still experimental; please make sure you\nfirst finish the guides above, then read and follow the instructions in this section.\nIf you are using Windows, make sure you use conda to install CMake and Ninja.\n\n.. code-block:: bash\n\n    conda install -c conda-forge cmake ninja git git-lfs zstd\n\nOn Windows, locating Java fails when the paths in environment variables contain spaces.\nMake sure you have a copy of Java in a path without spaces. The simplest way to do that\nis to copy Android Studio's JBR bundle to a directory without any spaces.\nIf Android Studio is installed at ``C:\\Program Files\\Android\\Android Studio\\``,\nyou can do the following:\n\n.. code-block:: bash\n\n   cp -r \"C:\\Program Files\\Android\\Android Studio\\jbr\" C:\\any-path-without-space\n   set JAVA_HOME=C:\\any-path-without-space\n\nOnce these are set up correctly, you can continue with the next steps.\n\nBuild Android App from Source\n-----------------------------\n\nThis section shows how to build the app from source.\n\nStep 1. 
Install Build Dependencies\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nFirst and foremost, please clone the `MLC LLM GitHub repository <https://github.com/mlc-ai/mlc-llm>`_.\nAfter cloning, go to the ``android/`` directory.\n\n.. code:: bash\n\n   git clone https://github.com/mlc-ai/mlc-llm.git\n   cd mlc-llm\n   git submodule update --init --recursive\n   cd android\n\n\n.. _android-build-runtime-and-model-libraries:\n\nStep 2. Build Runtime and Model Libraries\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe models to be built for the Android app are specified in ``MLCChat/mlc-package-config.json``.\nIn the ``model_list``,\n\n* ``model`` points to the Hugging Face repository which contains the pre-converted model weights. The Android app will download model weights from the Hugging Face URL.\n* ``model_id`` is a unique model identifier.\n* ``estimated_vram_bytes`` is an estimate of the VRAM the model takes at runtime.\n* ``\"bundle_weight\": true`` means the model weights will be bundled into the app when building.\n* ``overrides`` specifies some model config parameter overrides.\n\n\nWe have a one-line command to build and prepare all the model libraries:\n\n.. code:: bash\n\n   cd /path/to/MLCChat  # e.g., \"android/MLCChat\"\n   export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm  # must be an absolute path; ../.. does not work\n   mlc_llm package\n\nThis command mainly executes the following two steps:\n\n1. **Compile models.** We compile each model in ``model_list`` of ``MLCChat/mlc-package-config.json`` into a binary model library.\n2. **Build runtime and tokenizer.** In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM.\n\nThe command creates a ``./dist/`` directory that contains the runtime and model build output.\nPlease make sure all of the following files exist in ``./dist/``.\n\n.. 
code::\n\n   dist\n   └── lib\n       └── mlc4j\n           ├── build.gradle\n           ├── output\n           │   ├── arm64-v8a\n           │   │   └── libtvm4j_runtime_packed.so\n           │   └── tvm4j_core.jar\n           └── src\n               ├── cpp\n               │   └── tvm_runtime.h\n               └── main\n                   ├── AndroidManifest.xml\n                   ├── assets\n                   │   └── mlc-app-config.json\n                   └── java\n                       └── ...\n\nThe model execution logic on mobile GPUs is incorporated into ``libtvm4j_runtime_packed.so``,\nwhile ``tvm4j_core.jar`` is a lightweight (~60 KB) `Java binding <https://tvm.apache.org/docs/reference/api/javadoc/>`_\nto it. ``dist/lib/mlc4j`` is a Gradle subproject that you should include in your app\nso that the Android project can reference mlc4j (the MLC LLM Java library).\nThis library packages the dependent model libraries and the runtime needed to execute the model.\nTo include it, add the following to your project's ``settings.gradle``:\n\n.. code::\n\n   include ':mlc4j'\n   project(':mlc4j').projectDir = file('dist/lib/mlc4j')\n\n\n.. note::\n\n   We leverage a local JIT cache to avoid repetitive compilation of the same input.\n   However, it is sometimes helpful to force a rebuild when there is a new compiler update\n   or when something goes wrong with the cached library.\n   You can do so by setting the environment variable ``MLC_JIT_POLICY=REDO``:\n\n   .. code:: bash\n\n      MLC_JIT_POLICY=REDO mlc_llm package\n\n\nStep 3. Build Android App\n^^^^^^^^^^^^^^^^^^^^^^^^^\n\nOpen the ``./android/MLCChat`` folder as an Android Studio project.\nConnect your Android device to your machine.\nIn the menu bar of Android Studio, click **\"Build → Make Project\"**.\nOnce the build is finished, click **\"Run → Run 'app'\"** and you will see the app launched on your phone.\n\n.. 
note::\n    ❗ This app cannot run in an emulator; a physical phone is required because MLC LLM needs an actual mobile GPU to run at an accelerated speed.\n\n\nCustomize the App\n-----------------\n\nYou can customize the models built into the Android app by editing `MLCChat/mlc-package-config.json <https://github.com/mlc-ai/mlc-llm/blob/main/android/MLCChat/mlc-package-config.json>`_.\nEach field of the JSON file is described below.\n\nEach entry in ``\"model_list\"`` of the JSON file has the following fields:\n\n``model``\n   (Required) The path to the MLC-converted model to be built into the app.\n   It is a Hugging Face URL (e.g., ``\"model\": \"HF://mlc-ai/phi-2-q4f16_1-MLC\"``) that contains\n   the pre-converted model weights.\n\n``model_id``\n   (Required) A unique local identifier for the model.\n   It can be any string.\n\n``estimated_vram_bytes``\n   (Required) The estimated VRAM, in bytes, required to run the model.\n\n``bundle_weight``\n   (Optional) A boolean flag indicating whether to bundle model weights into the app. See :ref:`android-bundle-model-weights` below.\n\n``overrides``\n   (Optional) A dictionary to override the default model context window size (to limit the KV cache size) and prefill chunk size (to limit the model's temporary execution memory).\n   Example:\n\n   .. 
code:: json\n\n      {\n         \"device\": \"android\",\n         \"model_list\": [\n            {\n                  \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n                  \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n                  \"estimated_vram_bytes\": 1948348579,\n                  \"overrides\": {\n                     \"context_window_size\": 512,\n                     \"prefill_chunk_size\": 128\n                  }\n            }\n         ]\n      }\n\n``model_lib``\n   (Optional) A string specifying the system library prefix to use for the model.\n   This is typically used when you want to build multiple model variants with the same architecture into the app.\n   **This field does not affect any app functionality.**\n   The ``\"model_lib_path_for_prepare_libs\"`` field introduced below is also related.\n   Example:\n\n   .. code:: json\n\n      {\n         \"device\": \"android\",\n         \"model_list\": [\n            {\n                  \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n                  \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n                  \"estimated_vram_bytes\": 1948348579,\n                  \"model_lib\": \"gpt_neox_q4f16_1\"\n            }\n         ]\n      }\n\n\nBesides ``model_list`` in ``MLCChat/mlc-package-config.json``,\nyou can also **optionally** specify a dictionary ``\"model_lib_path_for_prepare_libs\"``\n**if you want to use manually compiled model libraries**.\nThe keys of this dictionary should be the ``model_lib`` values specified in the model list,\nand the values are the paths (absolute or relative) to the manually compiled model libraries.\nThe model libraries specified in ``\"model_lib_path_for_prepare_libs\"`` will be built into the app when running ``mlc_llm package``.\nExample:\n\n.. 
code:: json\n\n   {\n      \"device\": \"android\",\n      \"model_list\": [\n         {\n               \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n               \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n               \"estimated_vram_bytes\": 1948348579,\n               \"model_lib\": \"gpt_neox_q4f16_1\"\n         }\n      ],\n      \"model_lib_path_for_prepare_libs\": {\n         \"gpt_neox_q4f16_1\": \"../../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-android.tar\"\n      }\n   }\n\n.. _android-bundle-model-weights:\n\nBundle Model Weights\n--------------------\n\nThe previous sections show how to build an Android app with MLC LLM,\nbut the resulting app downloads model weights from Hugging Face at runtime,\nas configured in ``MLCChat/mlc-package-config.json``.\nHowever, it can be desirable to bundle the weights into the app to avoid downloading them over the network.\nThis section provides a simple ADB-based walkthrough that we hope helps with further development.\n\n**Enable weight bundling**.\nSet the field ``\"bundle_weight\": true`` for any model whose weights you want to bundle\nin ``MLCChat/mlc-package-config.json``, and run ``mlc_llm package`` again.\nBelow is an example:\n\n.. code:: json\n\n   {\n      \"device\": \"android\",\n      \"model_list\": [\n         {\n            \"model\": \"HF://mlc-ai/gemma-2b-it-q4f16_1-MLC\",\n            \"model_id\": \"gemma-2b-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"bundle_weight\": true\n         }\n      ]\n   }\n\nThe outcome of running ``mlc_llm package`` should be as follows:\n\n.. code::\n\n   dist\n   ├── bundle\n   │   ├── gemma-2b-q4f16_1   # The model weights that will be bundled into the app.\n   │   └── mlc-app-config.json\n   └── ...\n\n\n**Generate the APK**. Open Android Studio and click **\"Build → Generate Signed Bundle/APK\"** to build an APK for release. 
If this is the first time you are generating an APK, you will need to create a key by following `the official guide from Android <https://developer.android.com/studio/publish/app-signing#generate-key>`_.\nThe APK will be placed at ``android/MLCChat/app/release/app-release.apk``.\n\n**Install ADB and enable USB debugging**. Enable \"USB debugging\" in the Developer options of your phone settings.\nIn \"SDK Manager - SDK Tools\", install `Android SDK Platform-Tools <https://developer.android.com/studio/releases/platform-tools>`_.\nAdd the platform-tools directory to the ``PATH`` environment variable (on macOS, it is ``$HOME/Library/Android/sdk/platform-tools``).\nRun the following command; if ADB is installed correctly, your phone will appear as a device:\n\n.. code-block:: bash\n\n  adb devices\n\n**Install the APK and weights on your phone**.\nRun the commands below to install the app and push the local weights to the app data directory on your device.\nOnce this finishes, you can start the MLCChat app on your device.\nModels with ``bundle_weight`` set to true will already have their weights on the device.\n\n.. 
code-block:: bash\n\n  cd /path/to/MLCChat  # e.g., \"android/MLCChat\"\n  python bundle_weight.py --apk-path app/release/app-release.apk\n\nKnown issues\n------------\n\nOn Android devices equipped with Adreno GPUs, models whose format ends with a ``_1`` suffix have been observed to cause a system UI freeze of roughly 20 to 50 seconds during the prefill stage (the initialization before the first inference; the freeze does not recur on subsequent inferences with the same model instance).\nModels with a ``_0`` suffix have not been observed to experience this issue.\nThe two suffixes denote weight layouts that differ by a transpose operation.\nIf you encounter the freeze, the workaround is to use a model with the ``_0`` weight layout.\nFor more details, please consult `issue #3363 <https://github.com/mlc-ai/mlc-llm/issues/3363>`_.\n"
  },
  {
    "path": "docs/deploy/cli.rst",
    "content": ".. _deploy-cli:\n\nCLI\n===============\n\nMLC Chat CLI is a command-line tool for running MLC-compiled LLMs interactively out of the box.\n\n.. contents:: Table of Contents\n  :local:\n  :depth: 2\n\nInstall MLC-LLM Package\n------------------------\n\nChat CLI is part of the MLC-LLM package.\nTo use the chat CLI, first install MLC LLM by following the instructions :ref:`here <install-mlc-packages>`.\nOnce you have installed the MLC-LLM package, you can run the following command to check whether the installation was successful:\n\n.. code:: bash\n\n   mlc_llm chat --help\n\nYou should see the help message of ``mlc_llm chat`` if the installation was successful.\n\nQuick Start\n------------\n\nThis section provides a quick start guide to the MLC-LLM chat CLI.\nTo launch a CLI session, run the following command:\n\n.. code:: bash\n\n   mlc_llm chat MODEL [--model-lib PATH-TO-MODEL-LIB]\n\nwhere ``MODEL`` is the model folder produced by the :ref:`MLC-LLM build process <compile-model-libraries>`. Information about the other arguments can be found in the command reference below.\n\nOnce the chat CLI is ready, you can enter prompts to interact with the model.\n\n.. code::\n\n  You can use the following special commands:\n    /help               print the special commands\n    /exit               quit the cli\n    /stats              print out stats of last request (token/sec)\n    /metrics            print out full engine metrics\n    /reset              restart a fresh chat\n    /set [overrides]    override settings in the generation config. 
For example,\n                        `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`\n                        Note: Separate stop words in the `stop` option with commas (,).\n    Multi-line input: Use escape+enter to start a new line.\n\n  >>> What's the meaning of life?\n  The meaning of life is a philosophical and metaphysical question related to the purpose or significance of life or existence in general...\n\nRun CLI with Multi-GPU\n----------------------\n\nTo enable tensor parallelism and run LLMs on multiple GPUs, specify the argument ``--overrides \"tensor_parallel_shards=$NGPU\"``. For example:\n\n.. code:: shell\n\n  mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --overrides \"tensor_parallel_shards=2\"\n\n\nThe ``mlc_llm chat`` Command\n----------------------------\n\nThe full chat CLI interface is listed below for reference.\n\n.. code:: bash\n\n   mlc_llm chat MODEL [--model-lib PATH-TO-MODEL-LIB] [--device DEVICE] [--overrides OVERRIDES]\n\n\nMODEL                  The model folder produced by the MLC-LLM build process. The parameter\n                       can be either the model name with its quantization scheme\n                       (e.g. ``Llama-2-7b-chat-hf-q4f16_1``) or a full path to the model\n                       folder. In the former case, the provided name is used to search\n                       for the model folder over possible paths.\n\n--model-lib            The full path to the model library file to use (e.g. a ``.so`` file).\n--device               The device to run on, given as a string in the\n                       form of ``device_name:device_id`` or ``device_name``, where ``device_name`` is one of\n                       ``cuda``, ``metal``, ``vulkan``, ``rocm``, ``opencl``, ``auto`` (automatically detect the\n                       local device), and ``device_id`` is the device id to run on. 
The default value is ``auto``,\n                       with the device id defaulting to 0.\n--overrides            Model configuration overrides. Supports overriding\n                       ``context_window_size``, ``prefill_chunk_size``, ``sliding_window_size``, ``attention_sink_size``,\n                       and ``tensor_parallel_shards``. Multiple overrides are separated\n                       by semicolons, e.g. ``--overrides \"context_window_size=1024;prefill_chunk_size=128\"``.\n"
  },
  {
    "path": "docs/deploy/ide_integration.rst",
    "content": ".. _deploy-ide-integration:\n\nIDE Integration\n===============\n\n.. contents:: Table of Contents\n   :local:\n   :depth: 2\n\nMLC LLM now supports code completion in multiple IDEs. This means you can easily integrate an LLM with coding capabilities into your IDE through the MLC LLM :ref:`deploy-rest-api`. Here we provide a step-by-step guide on how to do this.\n\nConvert Your Model Weights\n--------------------------\n\nTo run a model with MLC LLM on any platform, you need to convert your model weights to the MLC format (e.g. `CodeLlama-7b-hf-q4f16_1-MLC <https://huggingface.co/mlc-ai/CodeLlama-7b-hf-q4f16_1-MLC>`__). You can always refer to :ref:`convert-weights-via-MLC` for in-depth details on how to convert your model weights. If you are using your own model weights (e.g., you fine-tuned the model on your personal codebase), it is important to follow these steps to convert the weights properly. Alternatively, you can download precompiled weights of the original models in the MLC format. See the full list of all precompiled weights `here <https://huggingface.co/mlc-ai>`__.\n\n**Example:**\n\n.. code:: bash\n\n   # convert model weights\n   mlc_llm convert_weight ./dist/models/CodeLlama-7b-hf \\\n      --quantization q4f16_1 \\\n      -o ./dist/CodeLlama-7b-hf-q4f16_1-MLC\n\nCompile Your Model\n------------------\n\nCompiling the model architecture is the crucial step for optimizing inference on a given platform. Compilation relies on multiple settings that affect the runtime. This configuration is specified inside the ``mlc-chat-config.json`` file, which can be generated by the ``gen_config`` command. You can learn more about the ``gen_config`` command `here </docs/compilation/compile_models.html#generate-mlc-chat-config>`__.\n\n**Example:**\n\n.. 
code:: bash\n\n   # generate mlc-chat-config.json\n   mlc_llm gen_config ./dist/models/CodeLlama-7b-hf \\\n      --quantization q4f16_1 --conv-template LM \\\n      -o ./dist/CodeLlama-7b-hf-q4f16_1-MLC\n\n.. note::\n   Make sure to set the ``--conv-template`` flag to ``LM``. This template is tailored to vanilla LLM completion, which is what code completion models generally use.\n\nAfter generating the MLC model configuration file, you are all set to compile the model library. You can learn more about the ``compile`` command `here </docs/compilation/compile_models.html#compile-model-library>`__.\n\n**Example:**\n\n.. tabs::\n\n   .. group-tab:: Linux - CUDA\n\n      .. code:: bash\n\n         # compile model library with specification in mlc-chat-config.json\n         mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \\\n            --device cuda -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so\n\n   .. group-tab:: Metal\n\n      For M-chip Mac:\n\n      .. code:: bash\n\n         # compile model library with specification in mlc-chat-config.json\n         mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \\\n            --device metal -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal.so\n\n      Cross-compiling for Intel Mac on M-chip Mac:\n\n      .. code:: bash\n\n         # compile model library with specification in mlc-chat-config.json\n         mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \\\n            --device metal:x86-64 -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal_x86_64.dylib\n\n      For Intel Mac:\n\n      .. code:: bash\n\n         # compile model library with specification in mlc-chat-config.json\n         mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \\\n            --device metal -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-metal_x86_64.dylib\n\n   .. group-tab:: Vulkan\n\n      For Linux:\n\n      .. 
code:: bash\n\n         # compile model library with specification in mlc-chat-config.json\n         mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \\\n            --device vulkan -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-vulkan.so\n\n      For Windows:\n\n      .. code:: bash\n\n         # compile model library with specification in mlc-chat-config.json\n         mlc_llm compile ./dist/CodeLlama-7b-hf-q4f16_1-MLC/mlc-chat-config.json \\\n            --device vulkan -o ./dist/libs/CodeLlama-7b-hf-q4f16_1-vulkan.dll\n\n.. note::\n   The generated model library can be shared across multiple model variants as long as the architecture and number of parameters do not change, e.g., the same architecture but different weights (such as your fine-tuned model).\n\nSetting up the Inference Entrypoint\n-----------------------------------\n\nYou can now deploy your compiled model locally with the MLC serve module. For more details about the MLC LLM API, visit our :ref:`deploy-rest-api` page.\n\n**Example:**\n\n.. code:: bash\n\n   python -m mlc_llm.serve.server \\\n      --model dist/CodeLlama-7b-hf-q4f16_1-MLC \\\n      --model-lib ./dist/libs/CodeLlama-7b-hf-q4f16_1-cuda.so\n\nConfigure the IDE Extension\n---------------------------\n\nAfter deploying the LLM, we can easily connect the IDE to the MLC REST API. In this guide, we use the Hugging Face code completion extension `llm-ls <https://github.com/huggingface/llm-ls>`__, which supports multiple IDEs (e.g., `vscode <https://github.com/huggingface/llm-vscode>`__, `intellij <https://github.com/huggingface/llm-intellij>`__ and `nvim <https://github.com/huggingface/llm.nvim>`__), to connect to an external OpenAI-compatible API (i.e., our MLC LLM :ref:`deploy-rest-api`).\n\nAfter installing the extension in your IDE, open the ``settings.json`` extension configuration file:\n\n.. 
figure:: /_static/img/ide_code_settings.png\n   :width: 450\n   :align: center\n   :alt: settings.json\n\n|\n\nThen, set the following entries to the respective values:\n\n.. code:: javascript\n\n   \"llm.modelId\": \"dist/CodeLlama-7b-hf-q4f16_1-MLC\",\n   \"llm.url\": \"http://127.0.0.1:8000/v1/completions\",\n   \"llm.backend\": \"openai\",\n\nThis enables the extension to send OpenAI-compatible requests to the MLC Serve API. You can also tune the API parameters; please refer to our :ref:`deploy-rest-api` documentation for more details about them.\n\n.. code:: javascript\n\n   \"llm.requestBody\": {\n      \"best_of\": 1,\n      \"frequency_penalty\": 0.0,\n      \"presence_penalty\": 0.0,\n      \"logprobs\": false,\n      \"top_logprobs\": 0,\n      \"logit_bias\": null,\n      \"max_tokens\": 128,\n      \"seed\": null,\n      \"stop\": null,\n      \"suffix\": null,\n      \"temperature\": 1.0,\n      \"top_p\": 1.0\n   }\n\nThe llm-ls extension supports a variety of code completion templates for different models. Choose the one that best matches your model, i.e., the template with the correct tokenizer and fill-in-the-middle (FIM) tokens.\n\n.. figure:: /_static/img/ide_code_templates.png\n   :width: 375\n   :align: center\n   :alt: llm-ls templates\n\n|\n\nOnce everything is set up, the extension will use responses from the MLC Serve API to provide off-the-shelf code completion in your IDE.\n\n.. figure:: /_static/img/code_completion.png\n   :width: 700\n   :align: center\n   :alt: IDE Code Completion\n\n|\n\nConclusion\n----------\n\nPlease let us know if you have any questions. Feel free to open an issue on the `MLC LLM repo <https://github.com/mlc-ai/mlc-llm/issues>`__!\n"
  },
  {
    "path": "docs/deploy/ios.rst",
    "content": ".. _deploy-ios:\n\niOS Swift SDK\n=============\n\n.. contents:: Table of Contents\n   :local:\n   :depth: 2\n\nThe MLC LLM iOS app can be installed in two ways: through the pre-built package or by building from source.\nIf you are an iOS user looking to try out the models, the pre-built package is recommended. If you are a\ndeveloper seeking to integrate new features into the package, building the iOS package from source is required.\n\nUse Pre-built iOS App\n---------------------\n\nThe MLC Chat app is available for free on the App Store. You can download and explore it by clicking the button below:\n\n    .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg\n      :width: 135\n      :target: https://apps.apple.com/us/app/mlc-chat/id6448482937\n\n\nBuild iOS App from Source\n-------------------------\n\nThis section shows how to build the app from source.\n\nStep 1. Install Build Dependencies\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nFirst and foremost, please clone the `MLC LLM GitHub repository <https://github.com/mlc-ai/mlc-llm>`_.\nAfter cloning, go to the ``ios/`` directory.\n\n.. code:: bash\n\n   git clone https://github.com/mlc-ai/mlc-llm.git\n   cd mlc-llm\n   git submodule update --init --recursive\n   cd ./ios\n\n\nPlease follow :doc:`/install/mlc_llm` to obtain a binary build of the mlc_llm package. Note that this\nis independent of the source code cloned above, which we use for the iOS package build.\nYou do not need to build mlc_llm for your host; the prebuilt package is sufficient for that purpose.\n\nWe also need the following build dependencies:\n\n* CMake >= 3.24\n* Git and Git-LFS\n* `Rust and Cargo <https://www.rust-lang.org/tools/install>`_, which are required by Hugging Face's tokenizer.\n\n.. _ios-build-runtime-and-model-libraries:\n\nStep 2. 
Build Runtime and Model Libraries\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe models to be built for the iOS app are specified in ``MLCChat/mlc-package-config.json``.\nIn its ``model_list``:\n\n* ``model`` points to the Hugging Face repository that contains the pre-converted model weights. The iOS app will download model weights from the Hugging Face URL.\n* ``model_id`` is a unique model identifier.\n* ``estimated_vram_bytes`` is an estimate of the VRAM the model takes at runtime.\n* ``\"bundle_weight\": true`` means the model's weights will be bundled into the app at build time.\n* ``overrides`` specifies model config parameter overrides.\n\n\nWe have a one-line command to build and prepare all the model libraries:\n\n.. code:: bash\n\n   cd /path/to/MLCChat  # e.g., \"ios/MLCChat\"\n   export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm  # e.g., \"../..\"\n   mlc_llm package\n\nThis command mainly executes the following two steps:\n\n1. **Compile models.** We compile each model in ``model_list`` of ``MLCChat/mlc-package-config.json`` into a binary model library.\n2. **Build runtime and tokenizer.** In addition to the model itself, a lightweight runtime and tokenizer are required to actually run the LLM.\n\nThe command creates a ``./dist/`` directory that contains the runtime and model build output.\nPlease make sure ``dist/`` follows the structure below (the model weights are optional):\n\n.. 
code::\n\n   dist\n   ├── bundle                   # The directory for mlc-app-config.json (and optionally model weights)\n   │   │                        # that will be bundled into the iOS app.\n   │   ├── mlc-app-config.json  # The app config JSON file.\n   │   └── [optional model weights]\n   └── lib\n      ├── libmlc_llm.a          # A lightweight interface to interact with LLM, tokenizer, and TVM runtime.\n      ├── libmodel_iphone.a     # The compiled model lib.\n      ├── libsentencepiece.a    # SentencePiece tokenizer.\n      ├── libtokenizers_cpp.a   # Hugging Face tokenizer.\n      └── libtvm_runtime.a      # TVM runtime.\n\n\n.. note::\n\n   We leverage a local JIT cache to avoid repetitive compilation of the same input.\n   However, it is sometimes helpful to force a rebuild when there is a new compiler update\n   or when something goes wrong with the cached library.\n   You can do so by setting the environment variable ``MLC_JIT_POLICY=REDO``:\n\n   .. code:: bash\n\n      MLC_JIT_POLICY=REDO mlc_llm package\n\n.. _ios-bundle-model-weights:\n\nStep 3. (Optional) Bundle model weights into the app\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nBy default, the app downloads the model weights from Hugging Face at runtime.\n**Optionally**, you can bundle model weights into the app:\nset the field ``\"bundle_weight\": true`` for any model whose weights you want to bundle\nin ``MLCChat/mlc-package-config.json``, and run ``mlc_llm package`` again.\nBelow is an example:\n\n.. code:: json\n\n   {\n      \"device\": \"iphone\",\n      \"model_list\": [\n         {\n            \"model\": \"HF://mlc-ai/gemma-2b-it-q4f16_1-MLC\",\n            \"model_id\": \"gemma-2b-q4f16_1\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n               \"prefill_chunk_size\": 128\n            },\n            \"bundle_weight\": true\n         }\n      ]\n   }\n\nThe outcome of running ``mlc_llm package`` should be as follows:\n\n.. 
code::\n\n   dist\n   ├── bundle\n   │   ├── gemma-2b-q4f16_1   # The model weights that will be bundled into the app.\n   │   └── mlc-app-config.json\n   └── ...\n\n.. _ios-build-app:\n\nStep 4. Build iOS App\n^^^^^^^^^^^^^^^^^^^^^\n\nOpen ``./ios/MLCChat/MLCChat.xcodeproj`` using Xcode. Note that you will need an\nApple Developer Account to use Xcode, and you may be prompted to use\nyour own developer team credentials and product bundle identifier.\n\nEnsure that all the necessary dependencies and configurations are\ncorrectly set up in the Xcode project.\n\nOnce you have made the necessary changes, build the iOS app using Xcode.\nIf you have an Apple Silicon Mac, you can select the target \"My Mac (designed for iPad)\"\nto run on your Mac. You can also run it directly on your iPad or iPhone.\n\n.. image:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/xcode-build.jpg\n   :align: center\n   :width: 60%\n\n|\n\nCustomize the App\n-----------------\n\nYou can customize the models built into the iOS app by editing `MLCChat/mlc-package-config.json <https://github.com/mlc-ai/mlc-llm/blob/main/ios/MLCChat/mlc-package-config.json>`_.\nEach field of the JSON file is described below.\n\nEach entry in ``\"model_list\"`` of the JSON file has the following fields:\n\n``model``\n   (Required) The path to the MLC-converted model to be built into the app.\n\n   It can be either a Hugging Face URL (e.g., ``\"model\": \"HF://mlc-ai/phi-2-q4f16_1-MLC\"``) or a path to a local model directory that contains converted model weights (e.g., ``\"model\": \"../dist/gemma-2b-q4f16_1\"``). 
Please check out :ref:`convert-weights-via-MLC` if you want to build a local model into the app.\n\n   *Note: the local path (if relative) is relative to the* ``ios/`` *directory.*\n\n``model_id``\n   (Required) A unique local identifier for the model.\n   It can be any string.\n\n``estimated_vram_bytes``\n   (Required) The estimated VRAM, in bytes, required to run the model.\n\n``bundle_weight``\n   (Optional) A boolean flag indicating whether to bundle model weights into the app. See :ref:`ios-bundle-model-weights`.\n\n``overrides``\n   (Optional) A dictionary to override the default model context window size (to limit the KV cache size) and prefill chunk size (to limit the model's temporary execution memory).\n   Example:\n\n   .. code:: json\n\n      {\n         \"device\": \"iphone\",\n         \"model_list\": [\n            {\n                  \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n                  \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1\",\n                  \"estimated_vram_bytes\": 2960000000,\n                  \"overrides\": {\n                     \"context_window_size\": 512,\n                     \"prefill_chunk_size\": 128\n                  }\n            }\n         ]\n      }\n\n``model_lib``\n   (Optional) A string specifying the system library prefix to use for the model.\n   This is typically used when you want to build multiple model variants with the same architecture into the app.\n   **This field does not affect any app functionality.**\n   The ``\"model_lib_path_for_prepare_libs\"`` field introduced below is also related.\n   Example:\n\n   .. 
code:: json\n\n      {\n         \"device\": \"iphone\",\n         \"model_list\": [\n            {\n                  \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n                  \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1\",\n                  \"estimated_vram_bytes\": 2960000000,\n                  \"model_lib\": \"gpt_neox_q4f16_1\"\n            }\n         ]\n      }\n\n\nBesides ``model_list`` in ``MLCChat/mlc-package-config.json``,\nyou can also **optionally** specify a dictionary ``\"model_lib_path_for_prepare_libs\"``\n**if you want to use manually compiled model libraries**.\nThe keys of this dictionary should be the ``model_lib`` values specified in the model list,\nand the values are the paths (absolute or relative) to the manually compiled model libraries.\nThe model libraries specified in ``\"model_lib_path_for_prepare_libs\"`` will be built into the app when running ``mlc_llm package``.\nExample:\n\n.. code:: json\n\n   {\n      \"device\": \"iphone\",\n      \"model_list\": [\n         {\n               \"model\": \"HF://mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\",\n               \"model_id\": \"RedPajama-INCITE-Chat-3B-v1-q4f16_1\",\n               \"estimated_vram_bytes\": 2960000000,\n               \"model_lib\": \"gpt_neox_q4f16_1\"\n         }\n      ],\n      \"model_lib_path_for_prepare_libs\": {\n         \"gpt_neox_q4f16_1\": \"../../dist/lib/RedPajama-INCITE-Chat-3B-v1-q4f16_1-iphone.tar\"\n      }\n   }\n\n\nBring Your Own Model\n--------------------\n\nThis section shows how to build your own model into the iOS app.\nWe use the `NeuralHermes <https://huggingface.co/mlabonne/NeuralHermes-2.5-Mistral-7B>`_ model, a variant of the Mistral model, as an example.\n\n.. note::\n\n  This section largely replicates :ref:`convert-weights-via-MLC`.\n  See that page for more details. Note that the weights are shared across\n  all platforms in MLC.\n\n**Step 1. 
Clone from HF and convert_weight**\n\nYou can work from the mlc-llm repo or from your own working directory. Note that all platforms\ncan share the same compiled/quantized weights. See :ref:`compile-command-specification`\nfor the specification of ``convert_weight``.\n\n.. code:: shell\n\n    # Create directory\n    mkdir -p dist/models && cd dist/models\n    # Clone HF weights\n    git lfs install\n    git clone https://huggingface.co/mlabonne/NeuralHermes-2.5-Mistral-7B\n    cd ../..\n    # Convert weight\n    mlc_llm convert_weight ./dist/models/NeuralHermes-2.5-Mistral-7B/ \\\n        --quantization q3f16_1 \\\n        -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC\n\n**Step 2. Generate MLC Chat Config**\n\nUse ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers.\nSee :ref:`compile-command-specification` for the specification of ``gen_config``.\n\n.. code:: shell\n\n    mlc_llm gen_config ./dist/models/NeuralHermes-2.5-Mistral-7B/ \\\n        --quantization q3f16_1 --conv-template neural_hermes_mistral \\\n        -o dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC\n\nFor the ``conv-template``, `conversation_template.py <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/conversation_template.py>`__\ncontains the full list of conversation templates that MLC provides.\n\nIf the model you are adding requires a new conversation template, you will need to add your own.\nFollow `this PR <https://github.com/mlc-ai/mlc-llm/pull/2163>`__ as an example.\nThe template to use is looked up via the ``conv_template`` field in ``mlc-chat-config.json``.\n\nFor more details, please see :ref:`configure-mlc-chat-json`.\n\n**Step 3. Upload weights to HF**\n\n.. 
code:: shell\n\n    # First, please create a repository on Hugging Face.\n    # With the repository created, run\n    git lfs install\n    git clone https://huggingface.co/my-huggingface-account/my-mistral-weight-huggingface-repo\n    cd my-mistral-weight-huggingface-repo\n    cp path/to/mlc-llm/dist/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC/* .\n    git add . && git commit -m \"Add mistral model weights\"\n    git push origin main\n\nAfter successfully following all steps, you should end up with a Hugging Face repo similar to\n`NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC <https://huggingface.co/mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC>`__,\nwhich includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files.\n\n\n**Step 4. Register in Model List**\n\nFinally, we add the model into the ``model_list`` of\n`MLCChat/mlc-package-config.json <https://github.com/mlc-ai/mlc-llm/blob/main/ios/MLCChat/mlc-package-config.json>`_ by specifying the Hugging Face link as ``model``:\n\n.. code:: json\n\n   {\n      \"device\": \"iphone\",\n      \"model_list\": [\n         {\n               \"model\": \"HF://mlc-ai/NeuralHermes-2.5-Mistral-7B-q3f16_1-MLC\",\n               \"model_id\": \"Mistral-7B-Instruct-v0.2-q3f16_1\",\n               \"estimated_vram_bytes\": 3316000000\n         }\n      ]\n   }\n\n\nNow, go through :ref:`ios-build-runtime-and-model-libraries` and :ref:`ios-build-app` again.\nThe app will use the ``NeuralHermes-Mistral`` model you just added.\n\n\nBuild Apps with MLC Swift API\n-----------------------------\n\nWe also provide a Swift package that you can use to build\nyour own app. 
The package is located under ``ios/MLCSwift``.\n\n- First, create ``mlc-package-config.json`` in your project folder.\n  You can do so by copying the one in the ``MLCChat`` folder.\n  Then run ``mlc_llm package``.\n  This produces the necessary libraries under ``/path/to/project/dist``.\n- Under \"Build phases\", add ``/path/to/project/dist/bundle``; this copies\n  the folder into your app to include the bundled weights and configs.\n- Add the ``ios/MLCSwift`` package to your app in Xcode.\n  Under \"Frameworks, Libraries, and Embedded Content\", click add package dependencies\n  and add a local package that points to ``ios/MLCSwift``.\n- Finally, add the library dependencies. Under build settings:\n\n  - Add the library search path ``/path/to/project/dist/lib``.\n  - Add the following items to \"other linker flags\".\n\n   .. code::\n\n      -Wl,-all_load\n      -lmodel_iphone\n      -lmlc_llm -ltvm_runtime\n      -ltokenizers_cpp\n      -lsentencepiece\n      -ltokenizers_c\n\n\nYou can then import the ``MLCSwift`` package into your app.\nThe following code shows an illustrative example of how to use the engine.\n\n.. code:: swift\n\n   import MLCSwift\n\n   func runExample() async {\n      let engine = MLCEngine()\n      let modelPath = \"/path/to/model/weights\"\n      let modelLib = \"model-lib-name\"\n\n      await engine.reload(modelPath: modelPath, modelLib: modelLib)\n\n      // run chat completion as in OpenAI API style\n      for await res in await engine.chat.completions.create(\n            messages: [\n               ChatCompletionMessage(\n                  role: .user,\n                  content: \"What is the meaning of life?\"\n               )\n            ]\n      ) {\n         print(res.choices[0].delta.content!.asText())\n      }\n   }\n\nCheck out `MLCEngineExample <https://github.com/mlc-ai/mlc-llm/blob/main/ios/MLCEngineExample>`_\nfor a minimal starter example.\n"
  },
  {
    "path": "docs/deploy/mlc_chat_config.rst",
    "content": ".. _configure-mlc-chat-json:\n\nCustomize MLC Chat Config\n=========================\n\n``mlc-chat-config.json`` is required for both compile-time and runtime, hence serving two purposes:\n\n1. Specify how we compile a model (shown in :ref:`compile-model-libraries`), and\n2. Specify conversation behavior in runtime.\n\n**This page focuses on the second purpose.** We explain the components of a chat\nconfiguration and how to customize them by modifying the file. Additionally,\nthe runtimes also provide APIs to optionally override some of the configurations.\n\nIn runtime, this file is stored under the directory of each compiled model\n(e.g. `RedPajama chat config <https://huggingface.co/mlc-ai/mlc-chat-RedPajama-INCITE-Chat-3B-v1-q4f16_1/blob/main/mlc-chat-config.json>`__).\n\n\n.. _struct-mlc-chat-conv:\n\nStructure of MLCChat Configuration\n----------------------------------\n\nBelow is the ``mlc-chat-config.json`` file corresponding to Llama2 model:\n\n.. code:: json\n\n  // mlc-chat-config.json\n  {\n    // 1. Metadata used to specify how to compile a model\n    \"model_type\": \"llama\",\n    \"quantization\": \"q4f16_1\",\n    \"version\": \"0.1.0\",\n    \"model_config\": {\n      \"hidden_size\": 4096,\n      \"intermediate_size\": 11008,\n      // more fields here...\n    },\n    \"vocab_size\": 32000,\n    \"context_window_size\": 4096,\n    \"sliding_window_size\": -1,\n    \"prefill_chunk_size\": 4096,\n    \"tensor_parallel_shards\": 1,\n\n    // 2. Tokenizer-related fields\n    \"pad_token_id\": 0,\n    \"bos_token_id\": 1,\n    \"eos_token_id\": 2,\n    \"tokenizer_files\": [\n      \"tokenizer.model\",\n      \"tokenizer.json\",\n      \"tokenizer_config.json\"\n    ]\n\n    // 3. 
Conversation template related fields\n    \"conv_template\": {\n      \"name\": \"llama-2\",\n      \"system_template\": \"[INST] <<SYS>>\\n{system_message}\\n<</SYS>>\\n\\n \",\n      \"system_message\": \"You are a helpful, respectful and honest assistant.\",\n      // more fields here...\n    },\n\n    // 4. Chat related fields that affect runtime behavior\n    \"temperature\": 0.6,\n    \"repetition_penalty\": 1.0,\n    \"top_p\": 0.9\n  }\n\n.. note::\n  Fields in the first part of ``mlc-chat-config.json`` (e.g. ``context_window_size``)\n  are only used at compile time. Changing them at runtime may lead to unexpected behavior.\n\n**As shown above, the file is divided into four parts. We focus on the third and fourth parts, which\ncan be customized to change the behavior of the model.**\n\n``conv_template``\n  .. note::\n    Legacy ``mlc-chat-config.json`` files may specify a string for this field to look up a registered conversation\n    template. This will be deprecated in the future. Re-generate the config using the latest version of ``mlc_llm``\n    to make sure this field is a complete JSON object.\n\n  The conversation template that this chat uses. For more information, please refer to :ref:`conversation structure <struct-conv>`.\n\n``temperature``\n  The temperature applied to logits before sampling. The default value is ``0.7``. A higher temperature encourages more diverse outputs, while a lower temperature produces more deterministic outputs.\n\n``repetition_penalty``\n  The repetition penalty controls the likelihood of the model generating repeated text. The default value is ``1.0``, indicating that no repetition penalty is applied. Increasing the value reduces the likelihood of repeated text. However, setting a high ``repetition_penalty`` may result in the model generating meaningless text. 
The ideal choice of repetition penalty may vary among models.\n\n  For more details on how repetition penalty controls text generation, please check out the `CTRL paper <https://arxiv.org/pdf/1909.05858.pdf>`_.\n\n``top_p``\n  This parameter determines the set of tokens from which we sample during decoding. The default value is ``0.95``. At each step, we sample from the minimal set of most likely tokens whose cumulative probability exceeds ``top_p``.\n\n  For additional information on top-p sampling, please refer to this `blog post <https://huggingface.co/blog/how-to-generate#top-p-nucleus-sampling>`_.\n\n\n.. _struct-conv:\n\nConversation Structure\n^^^^^^^^^^^^^^^^^^^^^^\n\nMLC-LLM provides a set of predefined conversation templates, which you can use directly by\nspecifying ``--conv-template [name]`` when generating the config. Below is a partial list of\nsupported conversation templates:\n\n- ``llama-2``\n- ``mistral_default``\n- ``chatml``\n- ``phi-2``\n- ...\n\nPlease refer to the `conversation_template <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/conversation_template>`_ directory for the full list of supported templates and their implementations.\n\nBelow is a generic structure of a JSON conversation configuration (we use Llama-2 as an example):\n\n.. 
code:: json\n\n  // mlc-chat-config.json\n  {\n    // ...\n    \"conv_template\": {\n      \"name\": \"llama-2\",\n      \"system_template\": \"[INST] <<SYS>>\\n{system_message}\\n<</SYS>>\\n\\n \",\n      \"system_message\": \"You are a helpful, respectful and honest assistant.\",\n      \"roles\": {\n        \"user\": \"[INST]\",\n        \"assistant\": \"[/INST]\",\n        \"tool\": \"[INST]\"\n      },\n      \"role_templates\": {\n        \"user\": \"{user_message}\",\n        \"assistant\": \"{assistant_message}\",\n        \"tool\": \"{tool_message}\"\n      },\n      \"messages\": [],\n      \"seps\": [\n        \" \"\n      ],\n      \"role_content_sep\": \" \",\n      \"role_empty_sep\": \" \",\n      \"stop_str\": [\n        \"[INST]\"\n      ],\n      \"stop_token_ids\": [\n        2\n      ],\n      \"function_string\": \"\",\n      \"use_function_calling\": false\n    }\n  }\n\n``name``\n    Name of the conversation.\n``system_template``\n    The system prompt template. It optionally contains the system\n    message placeholder, which is replaced with\n    the system message below.\n``system_message``\n    The content of the system prompt (without the template format).\n``system_prefix_token_ids``\n    The system token IDs to be prepended at the beginning of the tokenized\n    generated prompt.\n``roles``\n    The conversation roles.\n``role_templates``\n    The prompt template for each role. It optionally contains the default\n    message placeholders, which are replaced by the actual content.\n``messages``\n    The conversation history messages.\n    Each message is a pair of strings, denoting \"(role, content)\".\n    The content can be None.\n``seps``\n    An array of strings indicating the separators to be used after a user\n    message and a model message, respectively.\n``role_content_sep``\n    The separator between the role and the content in a message.\n``role_empty_sep``\n    The separator between the role and empty 
content.\n``stop_str``\n    When any string in ``stop_str`` is encountered, the model stops generating output.\n``stop_token_ids``\n    A list of token IDs that act as stop tokens.\n``function_string``\n    The function calling string.\n``use_function_calling``\n    Whether function calling is used; this helps check the output message format in API calls.\n\n\nGiven a conversation template, the prompt generated\nfrom it has the following format:\n\n.. code:: text\n\n  <<system>><<messages[0][0]>><<role_content_sep>><<messages[0][1]>><<seps[0]>>\n            <<messages[1][0]>><<role_content_sep>><<messages[1][1]>><<seps[1]>>\n            ...\n            <<messages[2][0]>><<role_content_sep>><<messages[2][1]>><<seps[0]>>\n            <<roles[1]>><<role_empty_sep>>\n"
  },
  {
    "path": "docs/deploy/python_engine.rst",
    "content": ".. _deploy-python-engine:\n\nPython API\n==========\n\n.. note::\n  This page introduces the Python API with MLCEngine in MLC LLM.\n\n.. contents:: Table of Contents\n  :local:\n  :depth: 2\n\n\nMLC LLM provides Python API through classes :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine`\nwhich **support full OpenAI API completeness** for easy integration into other Python projects.\n\nThis page introduces how to use the engines in MLC LLM.\nThe Python API is a part of the MLC-LLM package, which we have prepared pre-built pip wheels via\nthe :ref:`installation page <install-mlc-packages>`.\n\n\nVerify Installation\n-------------------\n\n.. code:: bash\n\n  python -c \"from mlc_llm import MLCEngine; print(MLCEngine)\"\n\nYou are expected to see the output of ``<class 'mlc_llm.serve.engine.MLCEngine'>``.\n\nIf the command above results in error, follow :ref:`install-mlc-packages` to install prebuilt pip\npackages or build MLC LLM from source.\n\n\nRun MLCEngine\n-------------\n\n:class:`mlc_llm.MLCEngine` provides the interface of OpenAI chat completion synchronously.\n:class:`mlc_llm.MLCEngine` does not batch concurrent request due to the synchronous design,\nand please use :ref:`AsyncMLCEngine <python-engine-async-llm-engine>` for request batching process.\n\n**Stream Response.** In :ref:`quick-start` and :ref:`introduction-to-mlc-llm`,\nwe introduced the basic use of :class:`mlc_llm.MLCEngine`.\n\n.. 
code:: python\n\n  from mlc_llm import MLCEngine\n\n  # Create engine\n  model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n  engine = MLCEngine(model)\n\n  # Run chat completion in OpenAI API.\n  for response in engine.chat.completions.create(\n      messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n      model=model,\n      stream=True,\n  ):\n      for choice in response.choices:\n          print(choice.delta.content, end=\"\", flush=True)\n  print(\"\\n\")\n\n  engine.terminate()\n\nThis code example first creates an :class:`mlc_llm.MLCEngine` instance with the 8B Llama-3 model.\n**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with the OpenAI API**,\nwhich means you can use :class:`mlc_llm.MLCEngine` in the same way as\n`OpenAI's Python package <https://github.com/openai/openai-python?tab=readme-ov-file#usage>`_\nfor both synchronous and asynchronous generation.\n\n**Non-stream Response.** The code example above uses the synchronous chat completion\ninterface and iterates over all the stream responses.\nIf you want to run without streaming, you can run\n\n.. code:: python\n\n  response = engine.chat.completions.create(\n      messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n      model=model,\n      stream=False,\n  )\n  print(response)\n\nPlease refer to `OpenAI's Python package <https://github.com/openai/openai-python?tab=readme-ov-file#usage>`_\nand `OpenAI chat completion API <https://platform.openai.com/docs/api-reference/chat/create>`_\nfor the complete chat completion interface.\n\n.. note::\n\n  If you want to enable tensor parallelism to run LLMs on multiple GPUs,\n  please pass an ``engine_config`` with ``tensor_parallel_shards`` set to the MLCEngine constructor.\n  For example,\n\n  .. 
code:: python\n\n    from mlc_llm import MLCEngine\n    from mlc_llm.serve.config import EngineConfig\n\n    model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n    engine = MLCEngine(\n        model,\n        engine_config=EngineConfig(tensor_parallel_shards=2),\n    )\n\n\n.. _python-engine-async-llm-engine:\n\nRun AsyncMLCEngine\n------------------\n\n:class:`mlc_llm.AsyncMLCEngine` provides the OpenAI chat completion interface with\nasynchronous features.\n**We recommend using** :class:`mlc_llm.AsyncMLCEngine` **to batch concurrent requests for better throughput.**\n\n**Stream Response.** The core use of :class:`mlc_llm.AsyncMLCEngine` for stream responses is as follows.\n\n.. code:: python\n\n  async for response in await engine.chat.completions.create(\n      messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n      model=model,\n      stream=True,\n  ):\n      for choice in response.choices:\n          print(choice.delta.content, end=\"\", flush=True)\n\n.. collapse:: Expand for a complete runnable example of AsyncMLCEngine in Python.\n\n  .. 
code:: python\n\n    import asyncio\n    from typing import Dict\n\n    from mlc_llm.serve import AsyncMLCEngine\n\n    model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n    prompts = [\n        \"Write a three-day travel plan to Pittsburgh.\",\n        \"What is the meaning of life?\",\n    ]\n\n\n    async def test_completion():\n        # Create engine\n        async_engine = AsyncMLCEngine(model=model)\n\n        num_requests = len(prompts)\n        output_texts: Dict[str, str] = {}\n\n        async def generate_task(prompt: str):\n            async for response in await async_engine.chat.completions.create(\n                messages=[{\"role\": \"user\", \"content\": prompt}],\n                model=model,\n                stream=True,\n            ):\n                if response.id not in output_texts:\n                    output_texts[response.id] = \"\"\n                output_texts[response.id] += response.choices[0].delta.content\n\n        tasks = [asyncio.create_task(generate_task(prompts[i])) for i in range(num_requests)]\n        await asyncio.gather(*tasks)\n\n        # Print output.\n        for request_id, output in output_texts.items():\n            print(f\"Output of request {request_id}:\\n{output}\\n\")\n\n        async_engine.terminate()\n\n\n    asyncio.run(test_completion())\n\n|\n\n**Non-stream Response.** Similarly, :class:`mlc_llm.AsyncMLCEngine` provides the non-stream response\ninterface.\n\n.. code:: python\n\n  response = await engine.chat.completions.create(\n      messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n      model=model,\n      stream=False,\n  )\n  print(response)\n\nPlease refer to `OpenAI's Python package <https://github.com/openai/openai-python?tab=readme-ov-file#usage>`_\nand `OpenAI chat completion API <https://platform.openai.com/docs/api-reference/chat/create>`_\nfor the complete chat completion interface.\n\n.. 
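note::\n\n  If you need the complete text of a streamed response rather than printing each chunk, you can accumulate the deltas yourself.\n  The helper below is a minimal sketch, not an MLC LLM API: ``collect_stream`` is a hypothetical name, and it assumes\n  that each streamed chunk exposes ``choices[i].delta.content`` as in the examples on this page and that each request\n  asks for a single choice (``n=1``).\n\n  .. code:: python\n\n    from typing import Any, AsyncIterator\n\n\n    async def collect_stream(stream: AsyncIterator[Any]) -> str:\n        # Hypothetical helper (not an MLC LLM API): concatenate the text deltas\n        # of one streamed chat completion into a single string.\n        parts = []\n        async for chunk in stream:\n            for choice in chunk.choices:\n                # Some chunks may carry no text (empty or None content); skip them.\n                if choice.delta.content:\n                    parts.append(choice.delta.content)\n        return \"\".join(parts)\n\n  Inside an ``async`` function, with ``engine`` being an :class:`mlc_llm.AsyncMLCEngine`, it could be used as\n  ``text = await collect_stream(await engine.chat.completions.create(messages=messages, model=model, stream=True))``.\n\n.. 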
note::\n\n  If you want to enable tensor parallelism to run LLMs on multiple GPUs,\n  please pass an ``engine_config`` with ``tensor_parallel_shards`` set to the AsyncMLCEngine constructor.\n  For example,\n\n  .. code:: python\n\n    from mlc_llm import AsyncMLCEngine\n    from mlc_llm.serve.config import EngineConfig\n\n    model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n    engine = AsyncMLCEngine(\n        model,\n        engine_config=EngineConfig(tensor_parallel_shards=2),\n    )\n\n\nEngine Mode\n-----------\n\nTo ease the engine configuration, the constructors of :class:`mlc_llm.MLCEngine` and\n:class:`mlc_llm.AsyncMLCEngine` have an optional argument ``mode``,\nwhich takes one of three values: ``\"local\"``, ``\"interactive\"`` or ``\"server\"``.\nThe default mode is ``\"local\"``.\n\nEach mode denotes a predefined engine configuration for a different use case.\nThe choice of the mode controls the request concurrency of the engine,\nas well as the engine's KV cache token capacity (in other words, the maximum\nnumber of tokens that the engine's KV cache can hold),\nand further affects the GPU memory usage of the engine.\n\nIn short,\n\n- mode ``\"local\"`` uses low request concurrency and low KV cache capacity, which is suitable for cases where **there are few concurrent requests and the user wants to save GPU memory**.\n- mode ``\"interactive\"`` uses 1 as the request concurrency and low KV cache capacity, which is designed for **interactive use cases** such as chats and conversations.\n- mode ``\"server\"`` uses as much request concurrency and KV cache capacity as possible. 
This mode aims to **fully utilize the GPU memory for large server scenarios** where there may be many concurrent requests.\n\n**For system benchmarks, please select mode** ``\"server\"``.\nPlease refer to :ref:`python-engine-api-reference` for detailed documentation of the engine mode.\n\n\nDeploy Your Own Model with Python API\n-------------------------------------\n\nThe :ref:`introduction page <introduction-deploy-your-own-model>` explains how to deploy your\nown models with MLC LLM.\nThis section shows how to use the model weights you convert and the model library you build\nin :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine`.\n\nWe use `Phi-2 <https://huggingface.co/microsoft/phi-2>`_ as the example model.\n\n**Specify Model Weight Path.** Assuming you have converted the model weights for your own model,\nyou can construct a :class:`mlc_llm.MLCEngine` as follows:\n\n.. code:: python\n\n  from mlc_llm import MLCEngine\n\n  model = \"models/phi-2\"  # Assuming the converted phi-2 model weights are under \"models/phi-2\"\n  engine = MLCEngine(model)\n\n\n**Specify Model Library Path.** Further, if you build the model library on your own,\nyou can use it in :class:`mlc_llm.MLCEngine` by passing the library path through the argument ``model_lib``.\n\n.. code:: python\n\n  from mlc_llm import MLCEngine\n\n  model = \"models/phi-2\"\n  model_lib = \"models/phi-2/lib.so\"  # Assuming the phi-2 model library is built at \"models/phi-2/lib.so\"\n  engine = MLCEngine(model, model_lib=model_lib)\n\n\nThe same applies to :class:`mlc_llm.AsyncMLCEngine`.\n\n\n.. 
_python-engine-api-reference:\n\nAPI Reference\n-------------\n\nThe :class:`mlc_llm.MLCEngine` and :class:`mlc_llm.AsyncMLCEngine` classes provide the following constructors.\n\nMLCEngine and AsyncMLCEngine fully support the OpenAI API.\nPlease refer to `OpenAI's Python package <https://github.com/openai/openai-python?tab=readme-ov-file#usage>`_\nand `OpenAI chat completion API <https://platform.openai.com/docs/api-reference/chat/create>`_\nfor the complete chat completion interface.\n\n.. currentmodule:: mlc_llm\n\n.. autoclass:: MLCEngine\n  :members:\n  :exclude-members: evaluate\n  :undoc-members:\n  :show-inheritance:\n\n  .. automethod:: __init__\n\n.. autoclass:: AsyncMLCEngine\n  :members:\n  :exclude-members: evaluate\n  :undoc-members:\n  :show-inheritance:\n\n  .. automethod:: __init__\n"
  },
  {
    "path": "docs/deploy/rest.rst",
    "content": ".. _deploy-rest-api:\n\nREST API\n========\n\n.. contents:: Table of Contents\n   :local:\n   :depth: 2\n\nWe provide `REST API <https://www.ibm.com/topics/rest-apis#:~:text=the%20next%20step-,What%20is%20a%20REST%20API%3F,representational%20state%20transfer%20architectural%20style.>`_\nfor a user to interact with MLC-LLM in their own programs.\n\nInstall MLC-LLM Package\n------------------------\n\nSERVE is a part of the MLC-LLM package, installation instruction for which can be found :ref:`here <install-mlc-packages>`. Once you have install the MLC-LLM package, you can run the following command to check if the installation was successful:\n\n.. code:: bash\n\n   mlc_llm serve --help\n\nYou should see serve help message if the installation was successful.\n\nQuick Start\n------------\n\nThis section provides a quick start guide to work with MLC-LLM REST API. To launch a server, run the following command:\n\n.. code:: bash\n\n   mlc_llm serve MODEL [--model-lib PATH-TO-MODEL-LIB]\n\nwhere ``MODEL`` is the model folder after compiling with :ref:`MLC-LLM build process <compile-model-libraries>`. Information about other arguments can be found under :ref:`Launch the server <rest_launch_server>` section.\n\nOnce you have launched the Server, you can use the API in your own program to send requests. Below is an example of using the API to interact with MLC-LLM in Python without Streaming (suppose the server is running on ``http://127.0.0.1:8080/``):\n\n.. 
code:: python\n\n   import requests\n\n   # Get a response using a prompt without streaming\n   payload = {\n      \"model\": \"./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\",\n      \"messages\": [\n         {\"role\": \"user\", \"content\": \"Write a haiku about apples.\"},\n      ],\n      \"stream\": False,\n      # \"n\": 1,\n      \"max_tokens\": 300,\n   }\n   r = requests.post(\"http://127.0.0.1:8080/v1/chat/completions\", json=payload)\n   choices = r.json()[\"choices\"]\n   for choice in choices:\n      print(f\"{choice['message']['content']}\\n\")\n\nRun CLI with Multi-GPU\n----------------------\n\nIf you want to enable tensor parallelism to run LLMs on multiple GPUs, please specify the argument ``--overrides \"tensor_parallel_shards=$NGPU\"``. For example,\n\n.. code:: shell\n\n   mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --overrides \"tensor_parallel_shards=2\"\n\n------------------------------------------------\n\n\n.. _rest_launch_server:\n\n\nLaunch the Server\n-----------------\n\nTo launch the MLC Server for MLC-LLM, run the following command in your terminal.\n\n.. code:: bash\n\n   mlc_llm serve MODEL [--model-lib PATH-TO-MODEL-LIB] [--device DEVICE] [--mode MODE] \\\n       [--additional-models ADDITIONAL-MODELS] \\\n       [--speculative-mode SPECULATIVE-MODE] \\\n       [--overrides OVERRIDES] \\\n       [--enable-tracing] \\\n       [--host HOST] \\\n       [--port PORT] \\\n       [--allow-credentials] \\\n       [--allowed-origins ALLOWED_ORIGINS] \\\n       [--allowed-methods ALLOWED_METHODS] \\\n       [--allowed-headers ALLOWED_HEADERS]\n\n\nMODEL                  The model folder after compiling with the MLC-LLM build process. The parameter\n                       can either be the model name with its quantization scheme\n                       (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model\n                       folder. 
In the former case, the provided name is used to search\n                       for the model folder over possible paths.\n\n--model-lib            A field to specify the full path to the model library file to use (e.g. a ``.so`` file).\n--device               The description of the device to run on. Users should provide a string in the\n                       form of ``device_name:device_id`` or ``device_name``, where ``device_name`` is one of\n                       ``cuda``, ``metal``, ``vulkan``, ``rocm``, ``opencl``, ``auto`` (automatically detect the\n                       local device), and ``device_id`` is the device id to run on. The default value is ``auto``,\n                       with the device id set to 0 by default.\n--mode                 The engine mode in MLC LLM.\n                       We provide three preset modes: ``local``, ``interactive`` and ``server``.\n                       The default mode is ``local``.\n\n                       The choice of mode decides the values of \"max_num_sequence\", \"max_total_sequence_length\"\n                       and \"prefill_chunk_size\" when they are not explicitly specified.\n\n                       1. Mode \"local\" refers to the local server deployment which has low\n                       request concurrency. So the max batch size will be set to 4, and the max\n                       total sequence length and prefill chunk size are set to the context\n                       window size (or sliding window size) of the model.\n\n                       2. Mode \"interactive\" refers to the interactive use of the server, which\n                       has at most 1 concurrent request. So the max batch size will be set to 1,\n                       and the max total sequence length and prefill chunk size are set to the context\n                       window size (or sliding window size) of the model.\n\n                       3. 
Mode \"server\" refers to the large server use case, which may handle\n                       many concurrent requests and aims to use as much GPU memory as possible.\n                       In this mode, we will automatically infer the largest possible max batch\n                       size and max total sequence length.\n\n                       You can manually specify arguments \"max_num_sequence\", \"max_total_seq_length\" and\n                       \"prefill_chunk_size\" via ``--overrides`` to override the automatically inferred values.\n                       For example: ``--overrides \"max_num_sequence=32;max_total_seq_length=4096\"``.\n--additional-models    The model paths and (optional) model library paths of additional models (other\n                       than the main model).\n\n                       When the engine is enabled with speculative decoding, additional models are needed.\n                       **Currently only one additional model is supported for speculative decoding.**\n                       The additional model is specified as\n                       ``--additional-models model_path_1`` or\n                       ``--additional-models model_path_1,model_lib_1``.\n\n                       When the model lib of a model is not given, JIT model compilation will be activated\n                       to compile the model automatically.\n--speculative-mode     The speculative decoding mode. 
Currently four options are supported:\n\n                       - ``disable``, where speculative decoding is not enabled.\n\n                       - ``small_draft``, denoting the normal speculative decoding (small draft) style.\n\n                       - ``eagle``, denoting the eagle-style speculative decoding.\n\n                       - ``medusa``, denoting the medusa-style speculative decoding.\n--overrides            Overrides extra configurable fields of EngineConfig.\n\n                       Fields that can be overridden: ``tensor_parallel_shards``, ``max_num_sequence``,\n                       ``max_total_seq_length``, ``prefill_chunk_size``, ``max_history_size``, ``gpu_memory_utilization``,\n                       ``spec_draft_length``, ``prefix_cache_max_num_recycling_seqs``, ``context_window_size``,\n                       ``sliding_window_size``, ``attention_sink_size``.\n\n                       Please check out the documentation of EngineConfig in ``mlc_llm/serve/config.py``\n                       for the detailed docstring of each field.\n                       Example: ``--overrides \"max_num_sequence=32;max_total_seq_length=4096;tensor_parallel_shards=2\"``\n--enable-tracing       A boolean indicating whether to enable event logging for requests.\n--host                 The host at which the server should be started, defaults to ``127.0.0.1``.\n--port                 The port on which the server should be started, defaults to ``8000``.\n--allow-credentials    A flag to indicate whether the server should allow credentials. If set, the server will\n                       include the ``CORS`` header in the response.\n--allowed-origins      Specifies the allowed origins. It expects a JSON list of strings, with the default value being ``[\"*\"]``, allowing all origins.\n--allowed-methods      Specifies the allowed methods. 
It expects a JSON list of strings, with the default value being ``[\"*\"]``, allowing all methods.\n--allowed-headers      Specifies the allowed headers. It expects a JSON list of strings, with the default value being ``[\"*\"]``, allowing all headers.\n\nYou can access ``http://127.0.0.1:PORT/docs`` (replace ``PORT`` with the port number you specified) to see the list of\nsupported endpoints.\n\nAPI Endpoints\n-------------\n\nThe REST API provides the following endpoints:\n\n.. http:get:: /v1/models\n\n------------------------------------------------\n\n   Get a list of models available for MLC-LLM.\n\n**Example**\n\n.. code:: python\n\n   import requests\n\n   url = \"http://127.0.0.1:8000/v1/models\"\n   headers = {\"accept\": \"application/json\"}\n\n   response = requests.get(url, headers=headers)\n\n   if response.status_code == 200:\n      print(\"Response:\")\n      print(response.json())\n   else:\n      print(\"Error:\", response.status_code)\n\n\n.. http:post:: /v1/chat/completions\n\n------------------------------------------------\n\n   Get a response from MLC-LLM using a prompt, either with or without streaming.\n\n**Chat Completion Request Object**\n\n- **messages** (*List[ChatCompletionMessage]*, required): A sequence of messages that have been exchanged in the conversation so far. 
Each message in the conversation is represented by a ``ChatCompletionMessage`` object, which includes the following fields:\n    - **content** (*Optional[Union[str, List[Dict[str, str]]]]*): The text content of the message or structured data in case of tool-generated messages.\n    - **role** (*Literal[\"system\", \"user\", \"assistant\", \"tool\"]*): The role of the message sender, indicating whether the message is from the system, user, assistant, or a tool.\n    - **name** (*Optional[str]*): An optional name for the sender of the message.\n    - **tool_calls** (*Optional[List[ChatToolCall]]*): A list of calls to external tools or functions made within this message, applicable when the role is ``assistant``.\n    - **tool_call_id** (*Optional[str]*): A unique identifier for the tool call, relevant when integrating external tools or services.\n\n- **model** (*str*, required): The model to be used for generating responses.\n\n- **frequency_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model’s likelihood to repeat tokens.\n\n- **presence_penalty** (*float*, optional, default=0.0): Positive values penalize new tokens if they are already present in the text so far, decreasing the model’s likelihood to repeat tokens.\n\n- **logprobs** (*bool*, optional, default=False): Indicates whether to include log probabilities for each token in the response.\n\n- **top_logprobs** (*int*, optional, default=0): An integer ranging from 0 to 20. It specifies the number of most likely tokens to return at each token position. Each token is accompanied by a log probability. 
If this parameter is used, 'logprobs' must be set to true.\n\n- **logit_bias** (*Optional[Dict[int, float]]*): Allows specifying biases for or against specific tokens during generation.\n\n- **max_tokens** (*Optional[int]*): The maximum number of tokens to generate in the response(s).\n\n- **n** (*int*, optional, default=1): Number of responses to generate for the given prompt.\n\n- **seed** (*Optional[int]*): A seed for deterministic generation. Using the same seed and inputs will produce the same output.\n\n- **stop** (*Optional[Union[str, List[str]]]*): One or more strings that, if encountered, will cause generation to stop.\n\n- **stream** (*bool*, optional, default=False): If `True`, responses are streamed back as they are generated.\n\n- **temperature** (*float*, optional, default=1.0): Controls the randomness of the generation. Lower values lead to less random completions.\n\n- **top_p** (*float*, optional, default=1.0): Nucleus sampling parameter that controls the diversity of the generated responses.\n\n- **tools** (*Optional[List[ChatTool]]*): Specifies external tools or functions that can be called as part of the chat.\n\n- **tool_choice** (*Optional[Union[Literal[\"none\", \"auto\"], Dict]]*): Controls how tools are selected for use in responses.\n\n- **user** (*Optional[str]*): An optional identifier for the user initiating the request.\n\n- **response_format** (*RequestResponseFormat*, optional): Specifies the format of the response. 
Can be either \"text\" or \"json_object\", with an optional schema definition for JSON responses.\n\n**Returns**\n\n- If `stream` is `False`, a `ChatCompletionResponse` object containing the generated response(s).\n- If `stream` is `True`, a stream of `ChatCompletionStreamResponse` objects, providing a real-time feed of generated responses.\n\n\n**ChatCompletionResponseChoice**\n\n- **finish_reason** (*Optional[Literal[\"stop\", \"length\", \"tool_calls\", \"error\"]]*, optional): The reason the completion process was terminated. It can be due to reaching a stop condition, the maximum length, output of tool calls, or an error.\n\n- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices.\n\n- **message** (*ChatCompletionMessage*, required): The message part of the chat completion, containing the content of the chat response.\n\n- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token.\n\n**ChatCompletionStreamResponseChoice**\n\n- **finish_reason** (*Optional[Literal[\"stop\", \"length\", \"tool_calls\"]]*, optional): Specifies why the streaming completion process ended. 
Valid reasons are \"stop\", \"length\", and \"tool_calls\".\n\n- **index** (*int*, required, default=0): Indicates the position of this choice within the list of choices.\n\n- **delta** (*ChatCompletionMessage*, required): Represents the incremental update or addition to the chat completion message in the stream.\n\n- **logprobs** (*Optional[LogProbs]*, optional): Optionally includes log probabilities for each output token.\n\n**ChatCompletionResponse**\n\n- **id** (*str*, required): A unique identifier for the chat completion.\n\n- **choices** (*List[ChatCompletionResponseChoice]*, required): A collection of `ChatCompletionResponseChoice` objects, representing the potential responses generated by the model.\n\n- **created** (*int*, required, default=current time): The UNIX timestamp representing when the response was generated.\n\n- **model** (*str*, required): The name of the model used to generate the chat completions.\n\n- **system_fingerprint** (*str*, required): A system-generated fingerprint that uniquely identifies the computational environment.\n\n- **object** (*Literal[\"chat.completion\"]*, required, default=\"chat.completion\"): A string literal indicating the type of object, here always \"chat.completion\".\n\n- **usage** (*UsageInfo*, required, default=empty `UsageInfo` object): Contains information about the API usage for this specific request.\n\n**ChatCompletionStreamResponse**\n\n- **id** (*str*, required): A unique identifier for the streaming chat completion.\n\n- **choices** (*List[ChatCompletionStreamResponseChoice]*, required): A list of `ChatCompletionStreamResponseChoice` objects, each representing a part of the streaming chat response.\n\n- **created** (*int*, required, default=current time): The creation time of the streaming response, represented as a UNIX timestamp.\n\n- **model** (*str*, required): Specifies the model that was used for generating the streaming chat completions.\n\n- **system_fingerprint** (*str*, 
required): A unique identifier for the system generating the streaming completions.\n\n- **object** (*Literal[\"chat.completion.chunk\"]*, required, default=\"chat.completion.chunk\"): A literal indicating that this object represents a chunk of a streaming chat completion.\n\n------------------------------------------------\n\n\n**Example**\n\nBelow is an example of using the API with streaming to interact with MLC-LLM in Python. The stream consists of lines of the form ``data: <json>`` and ends with ``data: [DONE]``, so the example reads the response line by line instead of assuming that each network chunk holds exactly one event.\n\n.. code:: python\n\n   import json\n\n   import requests\n\n   # Get a response using a prompt with streaming\n   payload = {\n      \"model\": \"./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/\",\n      \"messages\": [{\"role\": \"user\", \"content\": \"Write a haiku\"}],\n      \"stream\": True,\n   }\n   with requests.post(\"http://127.0.0.1:8080/v1/chat/completions\", json=payload, stream=True) as r:\n      for line in r.iter_lines():\n         line = line.decode(\"utf-8\")\n         if not line.startswith(\"data: \"):\n            continue  # skip the blank lines that separate events\n         data = line[len(\"data: \"):]\n         if data.strip() == \"[DONE]\":\n            break\n         response = json.loads(data)\n         content = response[\"choices\"][0][\"delta\"].get(\"content\") or \"\"\n         print(content, end=\"\", flush=True)\n   print(\"\\n\")\n\n------------------------------------------------\n\nThere is also support for function calling similar to OpenAI (https://platform.openai.com/docs/guides/function-calling). Below is an example of how to use function calling in Python.\n\n.. code:: python\n\n   import requests\n\n   tools = [\n      {\n         \"type\": \"function\",\n         \"function\": {\n               \"name\": \"get_current_weather\",\n               \"description\": \"Get the current weather in a given location\",\n               \"parameters\": {\n                  \"type\": \"object\",\n                  \"properties\": {\n                     \"location\": {\n                           \"type\": \"string\",\n                           \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                     },\n                     \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                  },\n                  \"required\": [\"location\"],\n               },\n         },\n      }\n   ]\n\n   payload = {\n      \"model\": \"./dist/gorilla-openfunctions-v1-q4f16_1-MLC/\",\n      \"messages\": [\n         {\n               \"role\": \"user\",\n               \"content\": \"What is the current weather in Pittsburgh, PA in fahrenheit?\",\n         }\n      ],\n      \"stream\": False,\n      \"tools\": tools,\n   }\n\n   r = requests.post(\"http://127.0.0.1:8080/v1/chat/completions\", json=payload)\n   print(f\"{r.json()['choices'][0]['message']['tool_calls'][0]['function']}\\n\")\n\n   # Output: {'name': 'get_current_weather', 'arguments': {'location': 'Pittsburgh, PA', 'unit': 'fahrenheit'}}\n\n------------------------------------------------\n\nFunction calling with streaming is also supported. Below is an example of how to use function calling with streaming in Python.\n\n.. code:: python\n\n   import requests\n   import json\n\n   tools = [\n      {\n         \"type\": \"function\",\n         \"function\": {\n               \"name\": \"get_current_weather\",\n               \"description\": \"Get the current weather in a given location\",\n               \"parameters\": {\n                  \"type\": \"object\",\n                  \"properties\": {\n                     \"location\": {\n                           \"type\": \"string\",\n                           \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                     },\n                     \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                  },\n                  \"required\": [\"location\"],\n               },\n         },\n      }\n   ]\n\n   payload = {\n      \"model\": \"./dist/gorilla-openfunctions-v1-q4f16_1-MLC/\",\n      \"messages\": [\n         {\n               \"role\": \"user\",\n               \"content\": \"What is the current weather in Pittsburgh, PA and Tokyo, JP in fahrenheit?\",\n         }\n      ],\n      \"stream\": True,\n      \"tools\": tools,\n   }\n\n   with requests.post(\"http://127.0.0.1:8080/v1/chat/completions\", json=payload, stream=True) as r:\n      # Each streamed event is a line of the form \"data: <json>\"\n      for line in r.iter_lines():\n         line = line.decode(\"utf-8\")\n         if not line.startswith(\"data: \"):\n            continue\n         data = line[len(\"data: \"):]\n         if data.strip() == \"[DONE]\":\n            break\n         response = json.loads(data)\n         content = response[\"choices\"][0][\"delta\"].get(\"content\") or \"\"\n         print(content, end=\"\", flush=True)\n   print(\"\\n\")\n\n   # Output: [\"get_current_weather(location='Pittsburgh,PA',unit='fahrenheit')\", \"get_current_weather(location='Tokyo,JP',unit='fahrenheit')\"]\n\n\n.. note::\n   Because this is a language-agnostic REST API, you can call these endpoints from any language with an HTTP client, not only Python.\n"
  },
  {
    "path": "docs/deploy/webllm.rst",
    "content": ".. _webllm-runtime:\n\nWebLLM Javascript SDK\n=====================\n\n.. contents:: Table of Contents\n   :local:\n   :depth: 2\n\n`WebLLM <https://www.npmjs.com/package/@mlc-ai/web-llm>`_ is a high-performance in-browser LLM\ninference engine, aiming to be the backend of AI-powered web applications and agents.\n\nIt provides a specialized runtime for the web backend of MLCEngine, leverages\n`WebGPU <https://www.w3.org/TR/webgpu/>`_ for local acceleration, offers OpenAI-compatible API,\nand provides built-in support for web workers to separate heavy computation from the UI flow.\n\nPlease checkout the `WebLLM repo <https://github.com/mlc-ai/web-llm>`__ on how to use WebLLM to build\nweb application in Javascript/Typescript. Here we only provide a high-level idea and discuss how to\nuse MLC-LLM to compile your own model to run with WebLLM.\n\nGetting Started\n---------------\n\nTo get started, try out `WebLLM Chat <https://chat.webllm.ai/>`__, which provides a great example\nof integrating WebLLM into a full web application.\n\nA WebGPU-compatible browser is needed to run WebLLM-powered web applications.\nYou can download the latest Google Chrome and use `WebGPU Report <https://webgpureport.org/>`__\nto verify the functionality of WebGPU on your browser.\n\nWebLLM is available as an `npm package <https://www.npmjs.com/package/@mlc-ai/web-llm>`_ and is\nalso CDN-delivered. 
Try a simple chatbot in\n`this JSFiddle example <https://jsfiddle.net/neetnestor/4nmgvsa2/>`__ without any setup.\n\nYou can also check out the `existing examples <https://github.com/mlc-ai/web-llm/tree/main/examples>`__\nfor more advanced usage of WebLLM, such as JSON mode and streaming.\n\nModel Records in WebLLM\n-----------------------\n\nEach model in `WebLLM Chat <https://chat.webllm.ai>`__ is registered as an instance of\n``ModelRecord`` and can be accessed at\n`webllm.prebuiltAppConfig.model_list <https://github.com/mlc-ai/web-llm/blob/main/src/config.ts#L293>`__.\n\nIn the most straightforward example, `get-started <https://github.com/mlc-ai/web-llm/blob/main/examples/get-started/src/get_started.ts>`__,\nthere are two ways to run a model.\n\nOne can either use a prebuilt model by simply passing its ``model_id`` to ``CreateMLCEngine()``, which loads the model via ``reload()``:\n\n.. code:: typescript\n\n  const selectedModel = \"Llama-3-8B-Instruct-q4f32_1-MLC\";\n  const engine = await webllm.CreateMLCEngine(selectedModel);\n\nOr one can specify their own model to run by creating a model record:\n\n.. code:: typescript\n\n  const appConfig: webllm.AppConfig = {\n    model_list: [\n      {\n        model: \"https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC\",\n        model_id: \"Llama-3-8B-Instruct-q4f32_1-MLC\",\n        model_lib:\n          webllm.modelLibURLPrefix +\n          webllm.modelVersion +\n          \"/Llama-3-8B-Instruct-q4f32_1-ctx4k_cs1k-webgpu.wasm\",\n      },\n      // Add your own models here...\n    ],\n  };\n  const selectedModel = \"Llama-3-8B-Instruct-q4f32_1-MLC\";\n  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(\n    selectedModel,\n    { appConfig: appConfig },\n  );\n\nAs the code above shows, just like on any other platform supported by MLC-LLM,\nrunning a model on WebLLM requires:\n\n1. **Model weights** converted to MLC format (e.g. 
`Llama-3-8B-Instruct-q4f32_1-MLC\n   <https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f32_1-MLC/tree/main>`_), downloaded through the URL ``ModelRecord.model``.\n2. **Model library** that comprises the inference logic (see the repo `binary-mlc-llm-libs <https://github.com/mlc-ai/binary-mlc-llm-libs/tree/main/web-llm-models>`__), downloaded through the URL ``ModelRecord.model_lib``.\n\nIn the sections below, we walk you through two examples of how to add your own model besides the ones in\n`webllm.prebuiltAppConfig.model_list <https://github.com/mlc-ai/web-llm/blob/main/src/config.ts#L293>`__.\nBefore proceeding, please verify your installation of ``mlc_llm`` and ``tvm``.\n\nVerify Installation for Adding Models\n-------------------------------------\n\n**Step 1. Verify mlc_llm**\n\nWe use the Python package ``mlc_llm`` to compile models. It can be installed by\nfollowing :ref:`install-mlc-packages`, either by building from source or by\ninstalling the prebuilt package. Verify the ``mlc_llm`` installation on the command line via:\n\n.. code:: bash\n\n    $ mlc_llm --help\n    # You should see help information with this line\n    usage: MLC LLM Command Line Interface. [-h] {compile,convert_weight,gen_config}\n\n.. note::\n    If you run into the error ``command not found: mlc_llm``, try ``python -m mlc_llm --help``.\n\n**Step 2. Verify TVM**\n\nTo compile models, you also need to follow :ref:`install-tvm`.\nHere we quickly verify ``tvm`` on the command line (for full verification, see :ref:`tvm-validate`):\n\n.. code:: bash\n\n    $ python -c \"import tvm; print(tvm.__file__)\"\n    /some-path/lib/python3.13/site-packages/tvm/__init__.py\n\n\n.. _webllm-add-model-variant:\n\nBring Your Own Model Variant\n----------------------------\n\nIn cases where the model you are adding is simply a variant of an existing\nmodel, we only need to convert weights and reuse an existing model library. 
For instance:\n\n- Adding ``OpenMistral`` when MLC supports ``Mistral``\n- Adding a ``Llama3`` fine-tuned on a domain-specific task when MLC supports ``Llama3``\n\n\nIn this section, we walk you through adding ``WizardMath-7B-V1.1-q4f16_1`` to the\n`get-started <https://github.com/mlc-ai/web-llm/tree/main/examples/get-started>`__ example.\nAccording to the model's ``config.json`` on `its Huggingface repo <https://huggingface.co/WizardLM/WizardMath-7B-V1.1/blob/main/config.json>`_,\nit reuses the Mistral model architecture.\n\n.. note::\n\n  This section largely replicates :ref:`convert-weights-via-MLC`.\n  See that page for more details. Note that the weights are shared across\n  all platforms in MLC.\n\n**Step 1. Clone from HF and convert_weight**\n\nYou can run the commands below from the mlc-llm repo or from your own working directory. Note that all platforms\ncan share the same compiled/quantized weights. See :ref:`compile-command-specification`\nfor the specification of ``convert_weight``.\n\n.. code:: shell\n\n    # Create directory\n    mkdir -p dist/models && cd dist/models\n    # Clone HF weights\n    git lfs install\n    git clone https://huggingface.co/WizardLM/WizardMath-7B-V1.1\n    cd ../..\n    # Convert weight\n    mlc_llm convert_weight ./dist/models/WizardMath-7B-V1.1/ \\\n        --quantization q4f16_1 \\\n        -o dist/WizardMath-7B-V1.1-q4f16_1-MLC\n\n**Step 2. Generate MLC Chat Config**\n\nUse ``mlc_llm gen_config`` to generate ``mlc-chat-config.json`` and process tokenizers.\nSee :ref:`compile-command-specification` for the specification of ``gen_config``.\n\n.. code:: shell\n\n    mlc_llm gen_config ./dist/models/WizardMath-7B-V1.1/ \\\n        --quantization q4f16_1 --conv-template wizard_coder_or_math \\\n        -o dist/WizardMath-7B-V1.1-q4f16_1-MLC/\n\nFor the ``conv-template`` option, the `conversation_template <https://github.com/mlc-ai/mlc-llm/tree/main/python/mlc_llm/conversation_template>`__\ndirectory contains the full list of conversation templates that MLC provides. 
You can also manually modify the ``mlc-chat-config.json`` to\nadd your customized conversation template.\n\n**Step 3. Upload weights to HF**\n\n.. code:: shell\n\n    # First, please create a repository on Hugging Face.\n    # With the repository created, run\n    git lfs install\n    git clone https://huggingface.co/my-huggingface-account/my-wizardMath-weight-huggingface-repo\n    cd my-wizardMath-weight-huggingface-repo\n    cp path/to/mlc-llm/dist/WizardMath-7B-V1.1-q4f16_1-MLC/* .\n    git add . && git commit -m \"Add wizardMath model weights\"\n    git push origin main\n\nAfter successfully following all steps, you should end up with a Huggingface repo similar to\n`WizardMath-7B-V1.1-q4f16_1-MLC <https://huggingface.co/mlc-ai/WizardMath-7B-V1.1-q4f16_1-MLC>`__,\nwhich includes the converted/quantized weights, the ``mlc-chat-config.json``, and tokenizer files.\n\n\n**Step 4. Register as a ModelRecord**\n\nFinally, we modify the code snippet for\n`get-started <https://github.com/mlc-ai/web-llm/blob/main/examples/get-started/src/get_started.ts>`__\nshown above.\n\nWe simply specify the Huggingface link as ``model``, while reusing the ``model_lib`` for\n``Mistral-7B``.\n\n.. 
code:: typescript\n\n  const appConfig: webllm.AppConfig = {\n    model_list: [\n      {\n        model: \"https://huggingface.co/mlc-ai/WizardMath-7B-V1.1-q4f16_1-MLC\",\n        model_id: \"WizardMath-7B-V1.1-q4f16_1-MLC\",\n        model_lib:\n          webllm.modelLibURLPrefix +\n          webllm.modelVersion +\n          \"/Mistral-7B-Instruct-v0.3-q4f16_1-ctx4k_cs1k-webgpu.wasm\",\n      },\n      // Add your own models here...\n    ],\n  };\n\n  // Must match the model_id registered in appConfig above\n  const selectedModel = \"WizardMath-7B-V1.1-q4f16_1-MLC\";\n  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(\n    selectedModel,\n    { appConfig: appConfig },\n  );\n\nNow, running the ``get-started`` example will use the ``WizardMath`` model you just added.\nSee `get-started's README <https://github.com/mlc-ai/web-llm/tree/main/examples/get-started#webllm-get-started-app>`__\nfor how to run it.\n\n\nBring Your Own Model Library\n----------------------------\n\nA model library is specified by:\n\n - The model architecture (e.g. ``llama-3``, ``gpt-neox``, ``phi-3``)\n - Quantization (e.g. ``q4f16_1``, ``q0f32``)\n - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning (currently only ``prefill-chunk-size`` affects the compiled model)\n - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``)\n\nIf the model you want to run is not compatible with the provided MLC\nprebuilt model libraries (e.g. it has a different quantization, a different\nmetadata spec, or even a different model architecture), you need to build your\nown model library.\n\nIn this section, we walk you through adding ``RedPajama-INCITE-Chat-3B-v1`` to the\n`get-started <https://github.com/mlc-ai/web-llm/tree/main/examples/get-started>`__ example.\n\nThis section largely replicates :ref:`compile-model-libraries`. See that page for\nmore details, specifically the ``WebGPU`` option.\n\n**Step 0. 
Install dependencies**\n\nTo compile model libraries for WebGPU, you need to :ref:`build mlc_llm from source <mlcchat_build_from_source>`.\nYou also need to follow :ref:`install-web-build`; otherwise, you will run into the following error:\n\n.. code:: text\n\n    RuntimeError: Cannot find libraries: wasm_runtime.bc\n\n**Step 1. Clone from HF and convert_weight**\n\nYou can run the commands below from the mlc-llm repo or from your own working directory. Note that all platforms\ncan share the same compiled/quantized weights.\n\n.. code:: shell\n\n    # Create directory\n    mkdir -p dist/models && cd dist/models\n    # Clone HF weights\n    git lfs install\n    git clone https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1\n    cd ../..\n    # Convert weight\n    mlc_llm convert_weight ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n        --quantization q4f16_1 \\\n        -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n\n**Step 2. Generate mlc-chat-config and compile**\n\nA model library is specified by:\n\n - The model architecture (e.g. ``llama-2``, ``gpt-neox``)\n - Quantization (e.g. ``q4f16_1``, ``q0f32``)\n - Metadata (e.g. ``context_window_size``, ``sliding_window_size``, ``prefill-chunk-size``), which affects memory planning\n - Platform (e.g. ``cuda``, ``webgpu``, ``iOS``)\n\nExcept for the platform, which is passed to ``compile`` via ``--device``, all these knobs are specified in the ``mlc-chat-config.json`` generated by ``gen_config``.\n\n.. code:: shell\n\n    # 1. gen_config: generate mlc-chat-config.json and process tokenizers\n    mlc_llm gen_config ./dist/models/RedPajama-INCITE-Chat-3B-v1/ \\\n        --quantization q4f16_1 --conv-template redpajama_chat \\\n        -o dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/\n    # 2. compile: compile model library with specification in mlc-chat-config.json\n    mlc_llm compile ./dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/mlc-chat-config.json \\\n        --device webgpu -o dist/libs/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm\n\n.. 
note::\n    When compiling larger models like ``Llama-3-8B``, you may want to add ``--prefill_chunk_size 1024``\n    to decrease memory usage. Otherwise, you may run into runtime errors like:\n\n    .. code:: text\n\n        TypeError: Failed to execute 'createBuffer' on 'GPUDevice': Failed to read the 'size' property from\n        'GPUBufferDescriptor': Value is outside the 'unsigned long long' value range.\n\n\n**Step 3. Distribute model library and model weights**\n\nAfter following the steps above, you should end up with:\n\n.. code:: shell\n\n    ~/mlc-llm > ls dist/libs\n      RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm  # ===> the model library\n\n    ~/mlc-llm > ls dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n      mlc-chat-config.json                             # ===> the chat config\n      tensor-cache.json                                # ===> the model weight info\n      params_shard_0.bin                               # ===> the model weights\n      params_shard_1.bin\n      ...\n      tokenizer.json                                   # ===> the tokenizer files\n      tokenizer_config.json\n\nUpload the ``RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm`` file to a GitHub repository (in our case,\nit is in `binary-mlc-llm-libs <https://github.com/mlc-ai/binary-mlc-llm-libs>`__). Then\nupload the ``RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC`` directory to a Huggingface repo:\n\n.. code:: shell\n\n    # First, please create a repository on Hugging Face.\n    # With the repository created, run\n    git lfs install\n    git clone https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo\n    cd my-redpajama3b-weight-huggingface-repo\n    cp path/to/mlc-llm/dist/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/* .\n    git add . 
&& git commit -m \"Add redpajama-3b chat model weights\"\n    git push origin main\n\nThis would result in something like `RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC\n<https://huggingface.co/mlc-ai/RedPajama-INCITE-Chat-3B-v1-q4f16_1-MLC/tree/main>`_.\n\n**Step 4. Register as a ModelRecord**\n\nFinally, we can run the model we added in WebLLM's `get-started <https://github.com/mlc-ai/web-llm/tree/main/examples/get-started>`__ example:\n\n.. code:: typescript\n\n  const appConfig: webllm.AppConfig = {\n    model_list: [\n      // Other records here omitted...\n      {\n        \"model\": \"https://huggingface.co/my-huggingface-account/my-redpajama3b-weight-huggingface-repo/resolve/main/\",\n        \"model_id\": \"RedPajama-INCITE-Chat-3B-v1\",\n        \"model_lib\": \"https://raw.githubusercontent.com/my-gh-account/my-repo/main/RedPajama-INCITE-Chat-3B-v1-q4f16_1-webgpu.wasm\",\n        \"required_features\": [\"shader-f16\"],\n      },\n    ],\n  };\n\n  const selectedModel = \"RedPajama-INCITE-Chat-3B-v1\";\n  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(\n    selectedModel,\n    { appConfig: appConfig },\n  );\n\nNow, running the ``get-started`` example will use the ``RedPajama`` model you just added.\nSee `get-started's README <https://github.com/mlc-ai/web-llm/tree/main/examples/get-started#webllm-get-started-app>`__\nfor how to run it.\n"
  },
  {
    "path": "docs/get_started/introduction.rst",
    "content": ".. _introduction-to-mlc-llm:\n\nIntroduction to MLC LLM\n=======================\n\n.. contents:: Table of Contents\n    :local:\n    :depth: 2\n\nMLC LLM is a machine learning compiler and high-performance deployment\nengine for large language models.  The mission of this project is to enable everyone to develop,\noptimize, and deploy AI models natively on everyone's platforms. \n\nThis page is a quick tutorial to introduce how to try out MLC LLM, and the steps to\ndeploy your own models with MLC LLM.\n\nInstallation\n------------\n\n:ref:`MLC LLM <install-mlc-packages>` is available via pip.\nIt is always recommended to install it in an isolated conda virtual environment.\n\nTo verify the installation, activate your virtual environment, run\n\n.. code:: bash\n\n  python -c \"import mlc_llm; print(mlc_llm.__path__)\"\n\nYou are expected to see the installation path of MLC LLM Python package.\n\n\nChat CLI\n--------\n\nAs the first example, we try out the chat CLI in MLC LLM with 4-bit quantized 8B Llama-3 model.\nYou can run MLC chat through a one-liner command:\n\n.. code:: bash\n\n    mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\n\nIt may take 1-2 minutes for the first time running this command.\nAfter waiting, this command launch a chat interface where you can enter your prompt and chat with the model.\n\n.. code::\n\n  You can use the following special commands:\n  /help               print the special commands\n  /exit               quit the cli\n  /stats              print out the latest stats (token/sec)\n  /reset              restart a fresh chat\n  /set [overrides]    override settings in the generation config. 
For example,\n                        `/set temperature=0.5;max_gen_len=100;stop=end,stop`\n                        Note: Separate stop words in the `stop` option with commas (,).\n  Multi-line input: Use escape+enter to start a new line.\n\n  user: What's the meaning of life\n  assistant:\n  What a profound and intriguing question! While there's no one definitive answer, I'd be happy to help you explore some perspectives on the meaning of life.\n\n  The concept of the meaning of life has been debated and...\n\n\nThe figure below shows what runs under the hood of this chat CLI command.\nThe first time the command runs, there are three major phases.\n\n- **Phase 1. Pre-quantized weight download.** This phase automatically downloads the pre-quantized Llama-3 model from `Hugging Face <https://huggingface.co/mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC>`_ and saves it to your local cache directory.\n- **Phase 2. Model compilation.** This phase automatically optimizes the Llama-3 model to accelerate model inference on GPU with machine learning compilation techniques in the `Apache TVM <https://llm.mlc.ai/docs/install/tvm.html>`_ compiler, and generates the binary model library that enables the execution of language models on your local GPU.\n- **Phase 3. Chat runtime.** This phase consumes the model library built in phase 2 and the model weights downloaded in phase 1, and launches a platform-native chat runtime to drive the execution of the Llama-3 model.\n\nWe cache the pre-quantized model weights and compiled model library locally.\nTherefore, phases 1 and 2 execute only **once** across multiple runs.\n\n.. figure:: /_static/img/project-workflow.svg\n  :width: 700\n  :align: center\n  :alt: Project Workflow\n\n  Workflow in MLC LLM\n\n.. note::\n\n  If you want to enable tensor parallelism to run LLMs on multiple GPUs,\n  please specify the argument ``--overrides \"tensor_parallel_shards=$NGPU\"``.\n  For example,\n\n  .. 
code:: shell\n\n    mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --overrides \"tensor_parallel_shards=2\"\n\n.. _introduction-to-mlc-llm-python-api:\n\nPython API\n----------\n\nIn the second example, we run the Llama-3 model with the chat completion Python API of MLC LLM.\nYou can save the code below into a Python file and run it.\n\n.. code:: python\n\n  from mlc_llm import MLCEngine\n\n  # Create engine\n  model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n  engine = MLCEngine(model)\n\n  # Run chat completion with the OpenAI-compatible API.\n  for response in engine.chat.completions.create(\n      messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n      model=model,\n      stream=True,\n  ):\n      for choice in response.choices:\n          print(choice.delta.content, end=\"\", flush=True)\n  print(\"\\n\")\n\n  engine.terminate()\n\n.. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg\n  :width: 500\n  :align: center\n\n  MLC LLM Python API\n\nThis code example first creates an :class:`mlc_llm.MLCEngine` instance with the 4-bit quantized Llama-3 model.\n**We design the Python API** :class:`mlc_llm.MLCEngine` **to align with the OpenAI API**,\nwhich means you can use :class:`mlc_llm.MLCEngine` in the same way as\n`OpenAI's Python package <https://github.com/openai/openai-python?tab=readme-ov-file#usage>`_\nfor both synchronous and asynchronous generation.\n\nIn this code example, we use the synchronous chat completion interface and iterate over\nall the streamed responses.\nTo run without streaming, use\n\n.. 
code:: python\n\n  response = engine.chat.completions.create(\n      messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n      model=model,\n      stream=False,\n  )\n  print(response)\n\nYou can also try the other arguments supported by the `OpenAI chat completion API <https://platform.openai.com/docs/api-reference/chat/create>`_.\nIf you would like to do concurrent asynchronous generation, you can use :class:`mlc_llm.AsyncMLCEngine` instead.\n\n.. note::\n\n  If you want to enable tensor parallelism to run LLMs on multiple GPUs,\n  pass an ``EngineConfig`` with ``tensor_parallel_shards`` set through the ``engine_config`` argument of the ``MLCEngine`` constructor.\n  For example,\n\n  .. code:: python\n\n    from mlc_llm import MLCEngine\n    from mlc_llm.serve.config import EngineConfig\n\n    model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n    engine = MLCEngine(\n        model,\n        engine_config=EngineConfig(tensor_parallel_shards=2),\n    )\n\n\nREST Server\n-----------\n\nFor the third example, we launch a REST server to serve the 4-bit quantized Llama-3 model\nfor OpenAI chat completion requests. The server can be launched from the command line with\n\n.. code:: bash\n\n  mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\n\nBy default, the server listens on ``http://127.0.0.1:8000``; you can use ``--host`` and ``--port``\nto set a different host and port.\nWhen the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``),\nyou can open a new shell and send a cURL request via the following command:\n\n.. code:: bash\n\n  curl -X POST \\\n    -H \"Content-Type: application/json\" \\\n    -d '{\n          \"model\": \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\",\n          \"messages\": [\n              {\"role\": \"user\", \"content\": \"Hello! Our project is MLC LLM. 
What is the name of our project?\"}\n          ]\n    }' \\\n    http://127.0.0.1:8000/v1/chat/completions\n\nThe server will process this request and send back the response.\nSimilar to :ref:`introduction-to-mlc-llm-python-api`, you can pass ``\"stream\": true``\nin the request body to receive streaming responses.\n\n.. note::\n\n  If you want to enable tensor parallelism to run LLMs on multiple GPUs,\n  pass the argument ``--overrides \"tensor_parallel_shards=$NGPU\"``, where ``$NGPU`` is the number of GPUs.\n  For example,\n\n  .. code:: shell\n\n    mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --overrides \"tensor_parallel_shards=2\"\n\n.. _introduction-deploy-your-own-model:\n\nDeploy Your Own Model\n---------------------\n\nSo far we have been using pre-converted model weights from Hugging Face.\nThis section introduces the core workflow for *running your own models with MLC LLM*.\n\nWe use `Phi-2 <https://huggingface.co/microsoft/phi-2>`_ as the example model.\nAssuming the Phi-2 model is downloaded and placed under ``models/phi-2``,\nthere are two major steps to prepare it for MLC LLM.\n\n- **Step 1. Generate MLC config.** The first step is to generate the MLC LLM configuration file.\n\n  .. 
code:: bash\n\n    export LOCAL_MODEL_PATH=models/phi-2   # The path where the model resides locally.\n    export MLC_MODEL_PATH=dist/phi-2-MLC/  # The path where the MLC-processed model is placed.\n    export QUANTIZATION=q0f16              # The choice of quantization.\n    export CONV_TEMPLATE=phi-2             # The choice of conversation template.\n    mlc_llm gen_config $LOCAL_MODEL_PATH \\\n        --quantization $QUANTIZATION \\\n        --conv-template $CONV_TEMPLATE \\\n        -o $MLC_MODEL_PATH\n\n  The config generation command takes in the local model path, the target path of MLC output,\n  the conversation template name in MLC and the quantization name in MLC.\n  Here the quantization ``q0f16`` means float16 without quantization,\n  and the conversation template ``phi-2`` is the Phi-2 model's template in MLC.\n\n  If you want to enable tensor parallelism on multiple GPUs, add the argument\n  ``--tensor-parallel-shards $NGPU`` to the config generation command.\n\n  - `The full list of supported quantization methods in MLC <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/quantization/quantization.py#L29>`_. You can try different quantization methods with MLC LLM. Typical quantization methods are ``q4f16_1`` for 4-bit group quantization and ``q4f16_ft`` for 4-bit FasterTransformer format quantization.\n  - `The full list of conversation templates in MLC <https://github.com/mlc-ai/mlc-llm/blob/main/python/mlc_llm/interface/gen_config.py#L276>`_.\n\n- **Step 2. Convert model weights.** In this step, we convert the model weights to MLC format.\n\n  .. 
code:: bash\n\n    mlc_llm convert_weight $LOCAL_MODEL_PATH \\\n      --quantization $QUANTIZATION \\\n      -o $MLC_MODEL_PATH\n\n  This step consumes the raw model weights and converts them to MLC format.\n  The converted weights will be stored under ``$MLC_MODEL_PATH``,\n  which is the same directory where the config file generated in Step 1 resides.\n\nNow you can run your own model with the chat CLI:\n\n.. code:: bash\n\n  mlc_llm chat $MLC_MODEL_PATH\n\nOn the first run, model compilation is triggered automatically to optimize the\nmodel for GPU acceleration and generate the binary model library.\nThe chat interface will be displayed after model JIT compilation finishes.\nYou can also use this model with the Python API, the REST server, and other deployment scenarios.\n\n(Optional) Compile Model Library\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn the previous sections, model libraries were compiled when the :class:`mlc_llm.MLCEngine` launched,\nwhich is what we call \"JIT (Just-in-Time) model compilation\".\nIn some cases, it is beneficial to explicitly compile the model libraries.\nShipping a precompiled library lets you deploy LLMs with fewer dependencies, since no compilation is needed at deployment time.\nIt also enables advanced options such as cross-compiling the libraries for web and mobile deployments.\n\nBelow is an example command for compiling model libraries in MLC LLM:\n\n.. code:: bash\n\n  export MODEL_LIB=$MLC_MODEL_PATH/lib.so  # \".dylib\" for Intel Macs.\n                                            # \".dll\" for Windows.\n                                            # \".wasm\" for web.\n                                            # \".tar\" for iPhone/Android.\n  mlc_llm compile $MLC_MODEL_PATH -o $MODEL_LIB\n\nAt runtime, specify this model library path to use it. For example,\n\n.. code:: bash\n\n  # For chat CLI\n  mlc_llm chat $MLC_MODEL_PATH --model-lib $MODEL_LIB\n  # For REST server\n  mlc_llm serve $MLC_MODEL_PATH --model-lib $MODEL_LIB\n\n.. 
code:: python\n\n  from mlc_llm import MLCEngine\n\n  # For Python API: use the MLC model directory and the library compiled above\n  model = \"dist/phi-2-MLC\"\n  model_lib = \"dist/phi-2-MLC/lib.so\"\n  engine = MLCEngine(model, model_lib=model_lib)\n\n:ref:`compile-model-libraries` introduces the model compilation command in detail,\nwhere you can find instructions and example commands to compile models for different\nhardware backends, such as WebGPU, iOS and Android.\n\nUniversal Deployment\n--------------------\n\nMLC LLM is a high-performance universal deployment solution for large language models,\nenabling native deployment of any large language model with native APIs and compiler acceleration.\nSo far, we have gone through several examples running on a local GPU environment.\nThe project supports multiple kinds of GPU backends.\n\nYou can use the ``--device`` option at compilation time and at runtime to pick a specific GPU backend.\nFor example, if you have an NVIDIA or AMD GPU, you can use the option below\nto run chat through the Vulkan backend. Vulkan-based LLM applications can also run in less typical\nenvironments (e.g. Steam Deck).\n\n.. 
code:: bash\n\n    mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC --device vulkan\n\nThe same core LLM runtime engine powers all the backends, enabling the same model to be deployed across backends as\nlong as it fits within the memory and computing budget of the corresponding hardware backend.\nWe also leverage machine learning compilation to build backend-specialized optimizations to\nget the best performance on the targeted backend when possible, and reuse key insights and optimizations\nacross the backends we support.\n\nPlease check out the What to Do Next section below to find out more about different deployment scenarios,\nsuch as WebGPU-based browser deployment, mobile and other settings.\n\nSummary and What to Do Next\n---------------------------\n\nTo briefly summarize this page:\n\n- We went through three examples of MLC LLM (chat CLI, Python API, and REST server).\n- We introduced how to convert model weights for your own models to run with MLC LLM, and (optionally) how to compile your models.\n- We also discussed the universal deployment capability of MLC LLM.\n\nNext, please feel free to check out the pages below for quick start examples and more detailed information\non specific platforms:\n\n- :ref:`Quick start examples <quick-start>` for Python API, chat CLI, REST server, web browser, iOS and Android.\n- Depending on your use case, check out our API documentation and tutorial pages:\n\n  - :ref:`webllm-runtime`\n  - :ref:`deploy-rest-api`\n  - :ref:`deploy-cli`\n  - :ref:`deploy-python-engine`\n  - :ref:`deploy-ios`\n  - :ref:`deploy-android`\n  - :ref:`deploy-ide-integration`\n\n- :ref:`Convert model weights to MLC format <convert-weights-via-MLC>`, if you want to run your own models.\n- :ref:`Compile model libraries <compile-model-libraries>`, if you want to deploy to web/iOS/Android or control the model optimizations.\n- Report any problem or ask any question: open new issues in our `GitHub repo 
<https://github.com/mlc-ai/mlc-llm/issues>`_.\n"
  },
  {
    "path": "docs/get_started/quick_start.rst",
    "content": ".. _quick-start:\n\nQuick Start\n===========\n\nExamples\n--------\n\nTo begin with, try out MLC LLM support for int4-quantized Llama3 8B.\nIt is recommended to have at least 6GB free VRAM to run it.\n\n.. tabs::\n\n  .. tab:: Python\n\n    **Install MLC LLM**. :ref:`MLC LLM <install-mlc-packages>` is available via pip.\n    It is always recommended to install it in an isolated conda virtual environment.\n\n    **Run chat completion in Python.** The following Python script showcases the Python API of MLC LLM:\n\n    .. code:: python\n\n      from mlc_llm import MLCEngine\n\n      # Create engine\n      model = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"\n      engine = MLCEngine(model)\n\n      # Run chat completion in OpenAI API.\n      for response in engine.chat.completions.create(\n          messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n          model=model,\n          stream=True,\n      ):\n          for choice in response.choices:\n              print(choice.delta.content, end=\"\", flush=True)\n      print(\"\\n\")\n\n      engine.terminate()\n\n    .. Todo: link the colab notebook when ready:\n\n    **Documentation and tutorial.** Python API reference and its tutorials are :ref:`available online <deploy-python-engine>`.\n\n    .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-engine-api.jpg\n      :width: 600\n      :align: center\n\n      MLC LLM Python API\n\n  .. tab:: REST Server\n\n    **Install MLC LLM**. :ref:`MLC LLM <install-mlc-packages>` is available via pip.\n    It is always recommended to install it in an isolated conda virtual environment.\n\n    **Launch a REST server.** Run the following command from command line to launch a REST server at ``http://127.0.0.1:8000``.\n\n    .. 
code:: shell\n\n      mlc_llm serve HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\n\n    **Send requests to server.** When the server is ready (showing ``INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)``),\n    open a new shell and send a request via the following command:\n\n    .. code:: shell\n\n      curl -X POST \\\n        -H \"Content-Type: application/json\" \\\n        -d '{\n              \"model\": \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\",\n              \"messages\": [\n                  {\"role\": \"user\", \"content\": \"Hello! Our project is MLC LLM. What is the name of our project?\"}\n              ]\n        }' \\\n        http://127.0.0.1:8000/v1/chat/completions\n\n    **Documentation and tutorial.** Check out :ref:`deploy-rest-api` for the REST API reference and tutorial.\n    Our REST API is compatible with the OpenAI API.\n\n    .. figure:: https://raw.githubusercontent.com/mlc-ai/web-data/main/images/mlc-llm/tutorials/python-serve-request.jpg\n      :width: 600\n      :align: center\n\n      Send HTTP request to REST server in MLC LLM\n\n  .. tab:: Command Line\n\n    **Install MLC LLM**. :ref:`MLC LLM <install-mlc-packages>` is available via pip.\n    It is always recommended to install it in an isolated conda virtual environment.\n\n    For Windows/Linux users, make sure the latest :ref:`Vulkan driver <vulkan_driver>` is installed.\n\n    **Run in the command line**.\n\n    .. code:: bash\n\n      mlc_llm chat HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\n\n\n    If you are using Windows/Linux/Steam Deck and would like to use Vulkan,\n    we recommend installing the necessary Vulkan loader dependency via conda\n    to avoid Vulkan-not-found errors.\n\n    .. code:: bash\n\n      conda install -c conda-forge gcc libvulkan-loader\n\n\n  .. tab:: Web Browser\n\n    `WebLLM <https://webllm.mlc.ai/#chat-demo>`__. 
MLC LLM generates performant code for WebGPU and WebAssembly,\n    so that LLMs can be run locally in a web browser without server resources.\n\n    **Download pre-quantized weights**. This step is self-contained in WebLLM.\n\n    **Download pre-compiled model library**. WebLLM automatically downloads WebGPU code to execute.\n\n    **Check browser compatibility**. The latest Google Chrome provides a WebGPU runtime, and `WebGPU Report <https://webgpureport.org/>`__ is a useful tool to verify the WebGPU capabilities of your browser.\n\n    .. figure:: https://blog.mlc.ai/img/redpajama/web.gif\n      :width: 300\n      :align: center\n\n      MLC LLM on Web\n\n  .. tab:: iOS\n\n    **Install MLC Chat iOS**. It is available on the App Store:\n\n    .. image:: https://developer.apple.com/assets/elements/badges/download-on-the-app-store.svg\n      :width: 135\n      :target: https://apps.apple.com/us/app/mlc-chat/id6448482937\n\n    |\n\n    **Note**. Larger models may require more VRAM; try starting with smaller models first.\n\n    **Tutorial and source code**. The source code of the iOS app is fully `open source <https://github.com/mlc-ai/mlc-llm/tree/main/ios>`__,\n    and a :ref:`tutorial <deploy-ios>` is included in the documentation.\n\n    .. figure:: https://blog.mlc.ai/img/redpajama/ios.gif\n      :width: 300\n      :align: center\n\n      MLC Chat on iOS\n\n  .. tab:: Android\n\n    **Install MLC Chat Android**. A prebuilt app is available as an APK:\n\n    .. image:: https://seeklogo.com/images/D/download-android-apk-badge-logo-D074C6882B-seeklogo.com.png\n      :width: 135\n      :target: https://github.com/mlc-ai/binary-mlc-llm-libs/releases/download/Android-09262024/mlc-chat.apk\n\n    |\n\n    **Note**. Larger models may require more VRAM; try starting with smaller models first.\n    The demo has been tested on:\n\n    - Samsung S23 with Snapdragon 8 Gen 2 chip\n    - Redmi Note 12 Pro with Snapdragon 685\n    - Google Pixel phones\n\n    **Tutorial and source code**. 
The source code of the Android app is fully `open source <https://github.com/mlc-ai/mlc-llm/tree/main/android>`__,\n    and a :ref:`tutorial <deploy-android>` is included in the documentation.\n\n    .. figure:: https://blog.mlc.ai/img/android/android-recording.gif\n      :width: 300\n      :align: center\n\n      MLC LLM on Android\n\n\nWhat to Do Next\n---------------\n\n- Check out :ref:`introduction-to-mlc-llm` for an introduction to the complete workflow in MLC LLM.\n- Depending on your use case, check out our API documentation and tutorial pages:\n\n  - :ref:`webllm-runtime`\n  - :ref:`deploy-rest-api`\n  - :ref:`deploy-cli`\n  - :ref:`deploy-python-engine`\n  - :ref:`deploy-ios`\n  - :ref:`deploy-android`\n  - :ref:`deploy-ide-integration`\n\n- :ref:`convert-weights-via-MLC`, if you want to run your own models.\n- :ref:`compile-model-libraries`, if you want to deploy to web/iOS/Android or control the model optimizations.\n- Report any problem or ask any question: open new issues in our `GitHub repo <https://github.com/mlc-ai/mlc-llm/issues>`_.\n"
  },
  {
    "path": "docs/index.rst",
    "content": "👋 Welcome to MLC LLM\n=====================\n\n`Discord <https://discord.gg/9Xpy2HGBuD>`_ | `GitHub <https://github.com/mlc-ai/mlc-llm>`_\n\n\n\n\nMLC LLM is a machine learning compiler and high-performance deployment\nengine for large language models.  The mission of this project is to enable\neveryone to develop, optimize, and deploy AI models natively on everyone's platforms. \n\nQuick Start\n-----------\n\nCheck out :ref:`quick-start` for quick start examples of using MLC LLM.\n\nIntroduction to MLC LLM\n-----------------------\n\nCheck out :ref:`introduction-to-mlc-llm` for the introduction and tutorial of a complete workflow in MLC LLM.\n\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Get Started\n   :hidden:\n\n   get_started/quick_start.rst\n   get_started/introduction.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Build and Deploy Apps\n   :hidden:\n\n   deploy/webllm.rst\n   deploy/rest.rst\n   deploy/cli.rst\n   deploy/python_engine.rst\n   deploy/ios.rst\n   deploy/android.rst\n   deploy/ide_integration.rst\n   deploy/mlc_chat_config.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Compile Models\n   :hidden:\n\n   compilation/convert_weights.rst\n   compilation/compile_models.rst\n   compilation/package_libraries_and_weights.rst\n   compilation/define_new_models.rst\n   compilation/configure_quantization.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Dependency Installation\n   :hidden:\n\n   install/tvm.rst\n   install/mlc_llm.rst\n   install/conda.rst\n   install/gpu.rst\n   install/emcc.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Microserving API\n   :hidden:\n\n   microserving/tutorial.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Community\n   :hidden:\n\n   community/guideline.rst\n   community/faq.rst\n\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Privacy\n   :hidden:\n\n   privacy.rst\n"
  },
  {
    "path": "docs/install/conda.rst",
    "content": "Install Conda\n=============\n\nMLC LLM does not depend on conda, but generally recommends it as a generic dependency manager, primarily because it provides a unified cross-platform experience that makes Windows/Linux/macOS development equally easy. Moreover, conda is Python-friendly and provides all the Python packages needed for MLC LLM, such as numpy.\n\n.. contents:: Table of Contents\n    :depth: 2\n\n\nInstall Miniconda\n-----------------\n\n**Use installer.** Miniconda, a minimal distribution of conda, comes with an out-of-the-box installer for Windows/macOS/Linux. Please refer to its `official website <https://docs.conda.io/en/latest/miniconda.html#latest-miniconda-installer-links>`_ for detailed instructions.\n\n**Set libmamba as the dependency solver.** The default dependency solver in conda can be slow in certain scenarios, so we recommend switching to libmamba, a faster solver.\n\n.. code-block:: bash\n   :caption: Set libmamba as the default solver\n\n   # update conda\n   conda update --yes -n base -c defaults conda\n   # install `conda-libmamba-solver`\n   conda install --yes -n base conda-libmamba-solver\n   # set it as the default solver\n   conda config --set solver libmamba\n\n.. note::\n    Conda is a generic dependency manager, which is not necessarily tied to any Python distribution.\n    In fact, some of our tutorials recommend using conda to install CMake, Git and Rust for a unified experience across OS platforms.\n\n\nValidate installation\n---------------------\n\n**Step 1. Check conda-arch mismatch.** macOS now runs on two different architectures, arm64 and x86_64, and installing conda for the wrong one is a common source of problems in MLC LLM, typically reported with error messages that mention \"architecture mismatch\". Use the following command to make sure the installed conda matches your machine's architecture:\n\n.. 
code-block:: bash\n   :caption: Check conda architecture\n\n   >>> conda info | grep platform\n   # for arm mac\n   platform : osx-arm64\n   # for x86 mac\n   platform : osx-64\n\n**Step 2. Check conda virtual environment.** If you have installed Python in your conda virtual environment, make sure conda, Python and pip are all from this environment:\n\n.. code-block:: bash\n   :caption: Check conda virtual environment (macOS, Linux)\n\n   >>> echo $CONDA_PREFIX\n   /.../miniconda3/envs/mlc-doc-venv\n   >>> which python\n   /.../miniconda3/envs/mlc-doc-venv/bin/python\n   >>> which pip\n   /.../miniconda3/envs/mlc-doc-venv/bin/pip\n\n.. code-block:: powershell\n   :caption: Check conda virtual environment (Windows)\n\n   >>> echo $Env:CONDA_PREFIX\n   \\...\\miniconda3\\envs\\mlc-doc-venv\n   >>> Get-Command python.exe\n   \\...\\miniconda3\\envs\\mlc-doc-venv\\python.exe\n   >>> Get-Command pip.exe\n   \\...\\miniconda3\\envs\\mlc-doc-venv\\Scripts\\pip.exe\n"
  },
  {
    "path": "docs/install/emcc.rst",
    "content": ".. _install-web-build:\n\nInstall Wasm Build Environment\n==============================\n\nThis page describes the steps to set up the build environment for WebAssembly and WebGPU builds.\n\nStep 1: Install EMSDK\n---------------------\n\nEmscripten is an LLVM-based compiler that compiles C/C++ source code to WebAssembly.\nEmscripten is required for the WebGPU build.\n\n- Please follow the installation instructions `here <https://emscripten.org/docs/getting_started/downloads.html#installation-instructions-using-the-emsdk-recommended>`__\n  to install the latest emsdk.\n- Source ``path/to/emsdk_env.sh`` so that ``emcc`` is reachable from ``PATH`` and the ``emcc`` command works.\n\nValidate that ``emcc`` is accessible in your shell:\n\n.. code:: bash\n\n    emcc --version\n\n.. note::\n    We recently found that using the latest ``emcc`` version may run into issues at runtime. Use\n    ``./emsdk install 3.1.56`` instead of ``./emsdk install latest`` for now as a workaround.\n\n    The error may look like:\n\n    .. code:: text\n\n        Init error, LinkError: WebAssembly.instantiate(): Import #6 module=\"wasi_snapshot_preview1\"\n        function=\"proc_exit\": function import requires a callable\n\n\nStep 2: Set TVM_SOURCE_DIR and MLC_LLM_SOURCE_DIR\n-------------------------------------------------\n\nWe need to set a path to a TVM source tree in order to build the TVM runtime.\nNote that you do not need to build TVM from source. The source here is only used to build the web runtime component.\nSet the environment variable in your shell startup profile to point to ``3rdparty/tvm`` (if preferred, you could also\npoint to your own TVM source directory if you installed TVM from source).\n\nWe also need to set ``MLC_LLM_SOURCE_DIR`` so that we can locate ``mlc_wasm_runtime.bc`` when compiling a model library to wasm.\n\n.. 
code:: bash\n\n    export TVM_SOURCE_DIR=/path/to/3rdparty/tvm\n    export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm\n\n\nStep 3: Prepare Wasm Runtime\n----------------------------\n\nFirst, we need to obtain a copy of the mlc-llm source code for the setup script:\n\n.. code:: bash\n\n    git clone https://github.com/mlc-ai/mlc-llm.git --recursive\n    cd mlc-llm\n\nNow we can prepare the wasm runtime using the script in the mlc-llm repo:\n\n.. code:: bash\n\n    ./web/prep_emcc_deps.sh\n\nWe can then validate the outcome:\n\n.. code:: bash\n\n    >>> echo ${TVM_SOURCE_DIR}\n\n    /path/set/in/step2\n\n    >>> ls -l ${TVM_SOURCE_DIR}/web/dist/wasm/*.bc\n\n    tvmjs_support.bc\n    wasm_runtime.bc\n    webgpu_runtime.bc\n"
  },
  {
    "path": "docs/install/gpu.rst",
    "content": "GPU Drivers and SDKs\n====================\n\n.. contents:: Table of Contents\n    :depth: 2\n\nMLC LLM is a universal deployment solution that allows efficient CPU/GPU code generation without AutoTVM-based performance tuning. This section focuses on generic GPU environment setup and troubleshooting.\n\nCUDA\n----\n\nCUDA is required to compile and run models with the CUDA backend.\n\nInstallation\n^^^^^^^^^^^^\n\nIf you have an NVIDIA GPU and you want to use models compiled with the CUDA\nbackend, you should install CUDA, which can be downloaded from\n`here <https://developer.nvidia.com/cuda-downloads>`__.\n\nValidate Installation\n^^^^^^^^^^^^^^^^^^^^^\n\nTo verify that you have correctly installed the CUDA runtime and NVIDIA driver, run ``nvidia-smi`` in the command line and check that it prints your GPU information.\n\nROCm\n----\n\nROCm is required to compile and run models with the ROCm backend.\n\nInstallation\n^^^^^^^^^^^^\n\nRight now MLC LLM only supports ROCm 6.1/6.2.\nIf you have an AMD GPU and you want to use models compiled with the ROCm\nbackend, you should install ROCm from `here <https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.2.0/install/quick-start.html>`__.\n\nValidate Installation\n^^^^^^^^^^^^^^^^^^^^^\n\nTo verify that you have correctly installed ROCm, run ``rocm-smi`` in the command line.\nIf you see the list of AMD devices printed out in a table, ROCm is correctly installed.\n\n.. _vulkan_driver:\n\nVulkan Driver\n-------------\n\nInstallation\n^^^^^^^^^^^^\n\nTo run pre-trained models (e.g. 
pulled from MLC-AI's Hugging Face repository) compiled with the Vulkan backend, you need to install a Vulkan driver on your machine.\n\nPlease check `this\npage <https://www.vulkan.org/tools#vulkan-gpu-resources>`__ and find the\nVulkan driver for your GPU vendor.\n\nAMD Radeon and Radeon PRO\n#########################\n\nFor AMD Radeon and Radeon PRO users, please download AMD's drivers from the official website (`Linux <https://www.amd.com/en/support/linux-drivers>`__ / `Windows <https://www.amd.com/en/support>`__).\nFor Linux users, after you install the ``amdgpu-install`` package, you can follow the instructions in its `documentation <https://amdgpu-install.readthedocs.io/en/latest/install-script.html>`__ to install\nthe driver. We recommend installing ROCr OpenCL and PRO Vulkan (proprietary) for best performance, which can be done by running the following command:\n\n.. code:: bash\n\n   amdgpu-install --usecase=graphics,opencl --opencl=rocr --vulkan=pro --no-32\n\nValidate Installation\n^^^^^^^^^^^^^^^^^^^^^\n\nTo verify whether the Vulkan installation is successful, we encourage you to install ``vulkaninfo``. Below are the instructions to install ``vulkaninfo`` on different platforms:\n\n.. tabs ::\n\n   .. code-tab :: bash Ubuntu/Debian\n\n      sudo apt-get update\n      sudo apt-get install vulkan-tools\n\n   .. code-tab :: bash Windows\n\n      # It comes with your GPU driver\n\n   .. code-tab :: bash Fedora\n\n      sudo dnf install vulkan-tools\n\n   .. code-tab :: bash Arch Linux\n\n      sudo pacman -S vulkan-tools\n      # Arch Linux maintains a helpful wiki page for Vulkan which you can refer to for troubleshooting: https://wiki.archlinux.org/title/Vulkan\n\n   .. code-tab :: bash Other Distributions\n\n      # Please install the Vulkan SDK for your platform\n      # https://vulkan.lunarg.com/sdk/home\n\n\nAfter installation, you can run ``vulkaninfo`` in the command line and check that it prints your GPU information.\n\n.. 
note::\n   WSL support for Windows is a work in progress at the moment. Please do not use WSL on Windows to run Vulkan.\n\nVulkan SDK\n----------\n\nThe Vulkan SDK is required for compiling models for the Vulkan backend. To build the TVM compiler from source, you will need to install the Vulkan SDK as a dependency, but our :doc:`pre-built wheels <../install/mlc_llm>` already ship with it.\n\nCheck the Vulkan SDK installation guide for your platform:\n\n.. tabs ::\n\n   .. tab :: Windows\n\n      `Getting Started with the Windows Tarball Vulkan SDK <https://vulkan.lunarg.com/doc/sdk/latest/windows/getting_started.html>`__\n\n   .. tab :: Linux\n\n      For Ubuntu users, please check\n      `Getting Started with the Ubuntu Vulkan SDK <https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started_ubuntu.html>`__\n\n      For other Linux distributions, please check\n      `Getting Started with the Linux Tarball Vulkan SDK <https://vulkan.lunarg.com/doc/sdk/latest/linux/getting_started.html>`__\n\n   .. tab :: Mac\n\n      `Getting Started with the macOS Vulkan SDK <https://vulkan.lunarg.com/doc/sdk/latest/mac/getting_started.html>`__\n\nPlease refer to the installation and setup page for the next steps to build TVM from source.\n\nOpenCL SDK\n----------\n\nThe OpenCL SDK is only required when you want to build your own models for the OpenCL backend. Please refer to `OpenCL's GitHub Repository <https://github.com/KhronosGroup/OpenCL-SDK>`__ for the OpenCL-SDK installation guide.\n\nOrange Pi 5 (RK3588 based SBC)\n------------------------------\n\nThe OpenCL SDK and the Mali GPU driver are required to compile and run models for the OpenCL backend.\n\nInstallation\n^^^^^^^^^^^^\n\n* Download and install Ubuntu 22.04 for your board from `here <https://github.com/Joshua-Riek/ubuntu-rockchip/releases/tag/v1.22>`__\n\n* Download and install ``libmali-g610.so``\n\n.. 
code-block:: bash\n\n   cd /usr/lib && sudo wget https://github.com/JeffyCN/mirrors/raw/libmali/lib/aarch64-linux-gnu/libmali-valhall-g610-g6p0-x11-wayland-gbm.so\n\n* Check whether the file ``mali_csffw.bin`` exists under ``/lib/firmware``; if not, download it with:\n\n.. code-block:: bash\n\n   cd /lib/firmware && sudo wget https://github.com/JeffyCN/mirrors/raw/libmali/firmware/g610/mali_csffw.bin\n\n* Download the OpenCL ICD loader and manually register libmali as an ICD\n\n.. code-block:: bash\n\n   sudo apt update\n   sudo apt install mesa-opencl-icd\n   sudo mkdir -p /etc/OpenCL/vendors\n   echo \"/usr/lib/libmali-valhall-g610-g6p0-x11-wayland-gbm.so\" | sudo tee /etc/OpenCL/vendors/mali.icd\n\n* Download and install ``libOpenCL``\n\n.. code-block:: bash\n\n   sudo apt install ocl-icd-opencl-dev\n\n* Download and install dependencies for Mali OpenCL\n\n.. code-block:: bash\n\n   sudo apt install libxcb-dri2-0 libxcb-dri3-0 libwayland-client0 libwayland-server0 libx11-xcb1\n\n* Download and install ``clinfo`` to check whether OpenCL is successfully installed\n\n.. code-block:: bash\n\n   sudo apt install clinfo\n\nValidate Installation\n^^^^^^^^^^^^^^^^^^^^^\n\nTo verify that you have correctly installed the OpenCL runtime and Mali GPU driver, run ``clinfo`` in the command line and check that it prints your GPU information.\nYou are expected to see output similar to the following:\n\n.. code-block:: bash\n\n   $ clinfo\n   arm_release_ver: g13p0-01eac0, rk_so_ver: 3\n   Number of platforms                               2\n      Platform Name                                   ARM Platform\n      Platform Vendor                                 ARM\n      Platform Version                                OpenCL 2.1 v1.g6p0-01eac0.2819f9d4dbe0b5a2f89c835d8484f9cd\n      Platform Profile                                FULL_PROFILE\n      ...\n"
  },
  {
    "path": "docs/install/mlc_llm.rst",
    "content": ".. _install-mlc-packages:\n\nInstall MLC LLM Python Package\n==============================\n\n.. contents:: Table of Contents\n    :local:\n    :depth: 2\n\nMLC LLM Python Package can be installed directly from a prebuilt developer package, or built from source.\n\nOption 1. Prebuilt Package\n--------------------------\n\nWe provide nightly built pip wheels for MLC-LLM via pip.\nSelect your operating system/compute platform and run the command in your terminal:\n\n.. note::\n    ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts.\n    Please make sure your conda environment has Python and pip installed.\n\n.. tabs::\n\n    .. tab:: Linux\n\n        .. tabs::\n\n            .. tab:: CPU\n\n                .. code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cpu mlc-ai-nightly-cpu\n\n            .. tab:: CUDA 12.8\n\n                .. code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu128 mlc-ai-nightly-cu128\n\n            .. tab:: CUDA 13.0\n\n                .. code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cu130 mlc-ai-nightly-cu130\n\n            .. tab:: ROCm 6.1\n\n                .. code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm61 mlc-ai-nightly-rocm61\n\n            .. tab:: ROCm 6.2\n\n                .. 
code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-rocm62 mlc-ai-nightly-rocm62\n\n            .. tab:: Vulkan\n\n                Supported in all Linux packages. Check out the following instructions\n                to install the latest Vulkan loader to avoid Vulkan-not-found errors.\n\n                .. code-block:: bash\n\n                    conda install -c conda-forge gcc libvulkan-loader\n\n        .. note::\n            git-lfs is required on the system; you can install it via\n\n            .. code-block:: bash\n\n                conda install -c conda-forge git-lfs\n\n            If you encounter errors about GLIBC/GLIBCXX versions not being found, please install the latest C++ standard library (``libstdcxx-ng``) in conda:\n\n            .. code-block:: bash\n\n                conda install -c conda-forge libstdcxx-ng\n\n            We also recommend using Python 3.13; if you are creating a new environment,\n            you can use the following command:\n\n            .. code-block:: bash\n\n                conda create --name mlc-prebuilt python=3.13\n\n    .. tab:: macOS\n\n        .. tabs::\n\n            .. tab:: CPU + Metal\n\n                .. code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cpu mlc-ai-nightly-cpu\n\n        .. note::\n\n            Always check that conda is installed properly on macOS using the command below:\n\n            .. code-block:: bash\n\n                conda info | grep platform\n\n            It should return \"osx-64\" for Macs with Intel chips, and \"osx-arm64\" for Macs with Apple silicon.\n            git-lfs is required on the system; you can install it via\n\n            .. code-block:: bash\n\n                conda install -c conda-forge git-lfs\n\n    .. tab:: Windows\n\n        .. tabs::\n\n            .. 
tab:: CPU + Vulkan\n\n                .. code-block:: bash\n\n                    conda activate your-environment\n                    python -m pip install --pre -U -f https://mlc.ai/wheels mlc-llm-nightly-cpu mlc-ai-nightly-cpu\n\n        .. note::\n            Please make sure your conda environment comes with Python and pip.\n            Also make sure to install the following packages:\n            Vulkan loader, clang, git and git-lfs, to enable proper automatic download\n            and JIT compilation.\n\n            .. code-block:: bash\n\n                conda install -c conda-forge clang libvulkan-loader git-lfs git\n\n            If you encounter the error below:\n\n            .. code-block:: bash\n\n                FileNotFoundError: Could not find module 'path\\to\\site-packages\\tvm\\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax.\n\n            It is likely that ``zstd``, a dependency of LLVM, is missing. Please use the command below to install it:\n\n            .. code-block:: bash\n\n                conda install zstd\n\n\nThen you can verify the installation from the command line:\n\n.. code-block:: bash\n\n    python -c \"import mlc_llm; print(mlc_llm)\"\n    # Prints out: <module 'mlc_llm' from '/path-to-env/lib/python3.13/site-packages/mlc_llm/__init__.py'>\n\n|\n\n.. _mlcchat_build_from_source:\n\nOption 2. Build from Source\n---------------------------\n\nWe also provide options to build the MLC runtime libraries ``mlc_llm`` from source.\nThis step is useful when you want to make modifications or obtain a specific version of the MLC runtime.\n\n\n**Step 1. 
Set up build dependencies.** To build from source, you need to ensure that the following build dependencies are satisfied:\n\n* CMake >= 3.24\n* Git\n* `Rust and Cargo <https://www.rust-lang.org/tools/install>`_, required by Hugging Face's tokenizer\n* One of the GPU runtimes:\n\n    * CUDA >= 11.8 (NVIDIA GPUs)\n    * Metal (Apple GPUs)\n    * Vulkan (NVIDIA, AMD, Intel GPUs)\n\n.. code-block:: bash\n    :caption: Set up build dependencies in Conda\n\n    # make sure to start with a fresh environment\n    conda env remove -n mlc-chat-venv\n    # create the conda environment with build dependencies\n    conda create -n mlc-chat-venv -c conda-forge \\\n        \"cmake>=3.24\" \\\n        rust \\\n        git \\\n        python=3.13\n    # enter the build environment\n    conda activate mlc-chat-venv\n\n.. note::\n    At runtime, the :doc:`TVM </install/tvm>` compiler is not a dependency of the MLCChat CLI or Python API. Only TVM's runtime is required, which is automatically included in `3rdparty/tvm <https://github.com/mlc-ai/mlc-llm/tree/main/3rdparty>`_.\n    However, if you would like to compile your own models, you need to install the TVM compiler by following :doc:`TVM </install/tvm>`.\n\n**Step 2. Configure and build.** A standard git-based workflow is recommended to download MLC LLM, after which you can specify build requirements with our lightweight config generation tool:\n\n.. code-block:: bash\n    :caption: Configure and build\n\n    # clone from GitHub\n    git clone --recursive https://github.com/mlc-ai/mlc-llm.git && cd mlc-llm/\n    # create build directory\n    mkdir -p build && cd build\n    # generate build configuration\n    python ../cmake/gen_cmake_config.py\n    # build mlc_llm libraries\n    cmake .. && make -j $(nproc) && cd ..\n\n**Step 3. Install via Python.** We recommend that you install ``mlc_llm`` as a Python package, giving you\naccess to ``mlc_llm.compile``, ``mlc_llm.MLCEngine``, and the CLI.\nThere are two ways to do so:\n\n    .. tabs ::\n\n       .. 
code-tab :: bash Install via environment variable\n\n          export MLC_LLM_SOURCE_DIR=/path-to-mlc-llm\n          export PYTHONPATH=$MLC_LLM_SOURCE_DIR/python:$PYTHONPATH\n          alias mlc_llm=\"python -m mlc_llm\"\n\n       .. code-tab :: bash Install via pip local project\n\n          conda activate your-own-env\n          which python # make sure python is installed, expected output: path_to_conda/envs/your-own-env/bin/python\n          cd /path-to-mlc-llm/python\n          pip install -e .\n\n**Step 4. Validate installation.** You can check that the MLC libraries and the ``mlc_llm`` CLI were compiled successfully using the following commands:\n\n.. code-block:: bash\n    :caption: Validate installation\n\n    # expected to see `libmlc_llm.so` and `libtvm_runtime.so`\n    ls -l ./build/\n    # expected to see help message\n    mlc_llm chat -h\n\nFinally, you can verify the installation from the command line. The output should show the path you built from source:\n\n.. code:: bash\n\n   python -c \"import mlc_llm; print(mlc_llm)\"\n"
  },
  {
    "path": "docs/install/tvm.rst",
    "content": ".. _install-tvm:\n\nInstall TVM Compiler\n==========================\n\n.. contents:: Table of Contents\n    :local:\n    :depth: 2\n\n`TVM Unity <https://discuss.tvm.apache.org/t/establish-tvm-unity-connection-a-technical-strategy/13344>`__, the latest development in Apache TVM, is required to build MLC LLM. Its features include:\n\n- High-performance CPU/GPU code generation instantly without tuning;\n- Dynamic shape and symbolic shape tracking by design;\n- Supporting both inference and training;\n- Productive python-first compiler implementation. As a concrete example, MLC LLM compilation is implemented in pure python using its API.\n\nTVM can be installed directly from a prebuilt developer package, or built from source.\n\n.. _tvm-prebuilt-package:\n\nOption 1. Prebuilt Package\n--------------------------\n\nA nightly prebuilt Python package of Apache TVM is provided.\n\n.. note::\n    ❗ Whenever using Python, it is highly recommended to use **conda** to manage an isolated Python environment to avoid missing dependencies, incompatible versions, and package conflicts.\n\n.. tabs::\n\n   .. tab:: Linux\n\n      .. tabs::\n\n         .. tab:: CPU\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cpu\n\n         .. tab:: CUDA 12.8\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu128\n\n         .. tab:: CUDA 13.0\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cu130\n\n         .. tab:: ROCm 6.1\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm61\n\n         .. 
tab:: ROCm 6.2\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-rocm62\n\n         .. tab:: Vulkan\n\n            Supported in all Linux packages.\n\n      .. note::\n\n        If you encounter GLIBC-not-found issues, please install the latest ``libstdcxx-ng`` in conda:\n\n        .. code-block:: bash\n\n          conda install -c conda-forge libstdcxx-ng\n\n   .. tab:: macOS\n\n      .. tabs::\n\n         .. tab:: CPU + Metal\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cpu\n\n      .. note::\n\n        Always check that conda is installed properly on macOS using the command below:\n\n        .. code-block:: bash\n\n          conda info | grep platform\n\n        It should return \"osx-64\" on a Mac with an Intel chip, and \"osx-arm64\" on a Mac with an Apple chip.\n\n   .. tab:: Windows\n\n      .. tabs::\n\n         .. tab:: CPU + Vulkan\n\n            .. code-block:: bash\n\n              conda activate your-environment\n              python -m pip install --pre -U -f https://mlc.ai/wheels mlc-ai-nightly-cpu\n\n      .. note::\n        Make sure you also install the Vulkan loader and clang to avoid Vulkan-not-found\n        or clang-not-found errors (clang is needed for JIT compilation):\n\n        .. code-block:: bash\n\n            conda install -c conda-forge clang libvulkan-loader\n\n        If you encounter the error below:\n\n        .. code-block:: bash\n\n            FileNotFoundError: Could not find module 'path\\to\\site-packages\\tvm\\tvm.dll' (or one of its dependencies). Try using the full path with constructor syntax.\n\n        It is likely that ``zstd``, a dependency of LLVM, is missing. Please use the command below to install it:\n\n        .. code-block:: bash\n\n            conda install zstd\n\n.. 
_tvm-build-from-source:\n\nOption 2. Build from Source\n---------------------------\n\nWhile it is generally recommended to always use the prebuilt TVM, if you require more customization, you may need to build it from source. **NOTE:** This should only be attempted if you are familiar with the intricacies of C++, CMake, LLVM, Python, and other related systems.\n\n.. collapse:: Details\n\n    **Step 1. Set up build dependencies.** To build from source, you need to ensure that the following build dependencies are met:\n\n    - CMake >= 3.24\n    - LLVM >= 15\n\n      - For ROCm, please install LLVM >= 17 for ROCm 6.1 and LLVM >= 18 for ROCm 6.2.\n\n    - Git\n    - (Optional) CUDA >= 11.8 (targeting NVIDIA GPUs)\n    - (Optional) Metal (targeting Apple GPUs such as M1 and M2)\n    - (Optional) Vulkan (targeting NVIDIA, AMD, Intel and mobile GPUs)\n    - (Optional) OpenCL (targeting NVIDIA, AMD, Intel and mobile GPUs)\n\n    .. note::\n        - To target NVIDIA GPUs, either CUDA or Vulkan is required (CUDA is recommended);\n        - For AMD and Intel GPUs, Vulkan is necessary;\n        - When targeting Apple (macOS, iOS, iPadOS), Metal is a mandatory dependency;\n        - Some Android devices only support OpenCL, but most of them support Vulkan.\n\n    The easiest way to manage dependencies is via conda, which maintains a set of toolchains, including LLVM, across platforms. To create an environment with these build dependencies, simply use:\n\n    .. code-block:: bash\n        :caption: Set up build dependencies in conda\n\n        # make sure to start with a fresh environment\n        conda env remove -n tvm-build-venv\n        # create the conda environment with build dependencies\n        conda create -n tvm-build-venv -c conda-forge \\\n            \"llvmdev>=15\" \\\n            \"cmake>=3.24\" \\\n            git \\\n            python=3.13\n        # enter the build environment\n        conda activate tvm-build-venv\n\n    **Step 2. 
Configure and build.** A standard git-based workflow is recommended to download Apache TVM, after which you specify build requirements in ``config.cmake``:\n\n    .. code-block:: bash\n        :caption: Download TVM from GitHub\n\n        # clone from GitHub\n        git clone --recursive https://github.com/apache/tvm.git && cd tvm\n        # create the build directory\n        rm -rf build && mkdir build && cd build\n        # specify build requirements in `config.cmake`\n        cp ../cmake/config.cmake .\n\n    We then tweak the following flags by appending them to the end of the configuration file:\n\n    .. code-block:: bash\n        :caption: Configure build in ``config.cmake``\n\n        # controls default compilation flags\n        echo \"set(CMAKE_BUILD_TYPE RelWithDebInfo)\" >> config.cmake\n        # LLVM is a required dependency\n        echo \"set(USE_LLVM \\\"llvm-config --ignore-libllvm --link-static\\\")\" >> config.cmake\n        echo \"set(HIDE_PRIVATE_SYMBOLS ON)\" >> config.cmake\n        # GPU SDKs, turn on if needed\n        echo \"set(USE_CUDA   OFF)\" >> config.cmake\n        echo \"set(USE_ROCM   OFF)\" >> config.cmake\n        echo \"set(USE_METAL  OFF)\" >> config.cmake\n        echo \"set(USE_VULKAN OFF)\" >> config.cmake\n        echo \"set(USE_OPENCL OFF)\" >> config.cmake\n        # Below are options for CUDA, turn on if needed\n        # CMAKE_CUDA_ARCHITECTURES is the CUDA compute capability of your GPU.\n        # Examples: 89 for RTX 4090, 90a for H100/H200, 100a for B200.\n        # Reference: https://developer.nvidia.com/cuda-gpus\n        echo \"set(CMAKE_CUDA_ARCHITECTURES YOUR_CUDA_COMPUTE_CAPABILITY_HERE)\" >> config.cmake\n        echo \"set(USE_CUBLAS ON)\" >> config.cmake\n        echo \"set(USE_CUTLASS ON)\" >> config.cmake\n        echo \"set(USE_THRUST ON)\" >> config.cmake\n        echo \"set(USE_NVTX ON)\" >> config.cmake\n        # Below is the option for ROCm, turn on if needed\n        echo \"set(USE_HIPBLAS ON)\" >> 
config.cmake\n\n    .. note::\n        ``HIDE_PRIVATE_SYMBOLS`` is a configuration option that enables the ``-fvisibility=hidden`` flag. This flag helps prevent potential symbol conflicts between TVM and PyTorch, which arise because the two frameworks ship different versions of LLVM.\n\n        `CMAKE_BUILD_TYPE <https://cmake.org/cmake/help/latest/variable/CMAKE_BUILD_TYPE.html>`_ controls the default compilation flags:\n\n        - ``Debug`` sets ``-O0 -g``\n        - ``RelWithDebInfo`` sets ``-O2 -g -DNDEBUG`` (recommended)\n        - ``Release`` sets ``-O3 -DNDEBUG``\n\n    Once ``config.cmake`` is edited accordingly, kick off the build with the command below:\n\n    .. code-block:: bash\n        :caption: Build ``libtvm`` using CMake and make\n\n        cmake .. && make -j $(nproc) && cd ..\n\n    A successful build should produce ``libtvm`` and ``libtvm_runtime`` under the ``/path-to-tvm/build/`` directory.\n\n    After leaving the build environment ``tvm-build-venv``, there are two ways to install the successful build into your environment:\n\n    .. tabs ::\n\n       .. code-tab :: bash Install via environment variable\n\n          export PYTHONPATH=/path-to-tvm/python:$PYTHONPATH\n\n       .. code-tab :: bash Install via pip local project\n\n          conda activate your-own-env\n          conda install python # make sure python is installed\n          cd /path-to-tvm/python\n          pip install -e .\n\n.. `|` adds a blank line\n\n|\n\n.. _tvm-validate:\n\nValidate TVM Installation\n-------------------------\n\nUsing a compiler infrastructure with multiple language bindings can be error-prone.\nTherefore, it is highly recommended to validate the TVM installation before use.\n\n**Step 1. Locate TVM Python package.** The following command confirms that TVM is properly installed as a Python package and shows the location of the TVM Python package:\n\n.. 
code-block:: bash\n\n    >>> python -c \"import tvm; print(tvm.__file__)\"\n    /some-path/lib/python3.13/site-packages/tvm/__init__.py\n\n**Step 2. Confirm which TVM library is used.** When maintaining multiple builds or installations of TVM, it is important to double-check that the Python package is using the proper ``libtvm``, with the following command:\n\n.. code-block:: bash\n\n    >>> python -c \"import tvm; print(tvm.base._LIB)\"\n    <CDLL '/some-path/lib/python3.13/site-packages/tvm/libtvm.dylib', handle 95ada510 at 0x1030e4e50>\n\n**Step 3. Reflect TVM build options.** When a downstream application fails, the cause is often a wrong TVM commit or wrong build flags. The following command helps find out which commit and flags were used:\n\n.. code-block:: bash\n\n    >>> python -c \"import tvm; print('\\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))\"\n    ... # Omitted less relevant options\n    GIT_COMMIT_HASH: 4f6289590252a1cf45a4dc37bce55a25043b8338\n    HIDE_PRIVATE_SYMBOLS: ON\n    USE_LLVM: llvm-config --link-static\n    LLVM_VERSION: 15.0.7\n    USE_VULKAN: OFF\n    USE_CUDA: OFF\n    CUDA_VERSION: NOT-FOUND\n    USE_OPENCL: OFF\n    USE_METAL: ON\n    USE_ROCM: OFF\n\n.. note::\n    ``GIT_COMMIT_HASH`` indicates the exact commit of the TVM build, and it can be found on GitHub via ``https://github.com/mlc-ai/relax/commit/$GIT_COMMIT_HASH``.\n\n**Step 4. Check device detection.** It can be helpful to check whether TVM detects your device at all, using the following commands:\n\n.. code-block:: bash\n\n    >>> python -c \"import tvm; print(tvm.metal().exist)\"\n    True # or False\n    >>> python -c \"import tvm; print(tvm.cuda().exist)\"\n    False # or True\n    >>> python -c \"import tvm; print(tvm.vulkan().exist)\"\n    False # or True\n\nPlease note that the commands above check whether an actual device is present on the local machine, which the TVM runtime (not the compiler) needs in order to execute. 
However, the TVM compiler can perform compilation tasks without a physical device. As long as the necessary toolchain, such as NVCC, is available, TVM supports cross-compilation even in the absence of an actual device.\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\npushd %~dp0\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n    set SPHINXBUILD=sphinx-build\n)\nset SOURCEDIR=.\nset BUILDDIR=_build\n\n%SPHINXBUILD% >NUL 2>NUL\nif errorlevel 9009 (\n    echo.\n    echo.The 'sphinx-build' command was not found. Make sure you have Sphinx\n    echo.installed, then set the SPHINXBUILD environment variable to point\n    echo.to the full path of the 'sphinx-build' executable. Alternatively you\n    echo.may add the Sphinx directory to PATH.\n    echo.\n    echo.If you don't have Sphinx installed, grab it from\n    echo.https://www.sphinx-doc.org/\n    exit /b 1\n)\n\nif \"%1\" == \"\" goto help\n\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\ngoto end\n\n:help\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\n\n:end\npopd\n"
  },
  {
    "path": "docs/microserving/tutorial.rst",
    "content": "Implement LLM Cross-engine Orchestration Patterns\n======================================================================\n\nIn this tutorial, we will introduce how to implement LLM cross-engine\norchestration patterns, like prefill-decode disaggregation, in MLC-LLM\nvia microserving API. Aiming to make disaggregated serving programmable,\nMicroServing provides a new RISC-style approach to design LLM serving\nAPI at sub-request level. It enables programmable cross-engine serving\npatterns in a few lines of python code. For more information of\nmicroserving API, check out\nhttps://blog.mlc.ai/2025/01/07/microserving-llm-engines.\n\nBelow is an example of prefill-decode disaggregation implementation. An\nLLM cross-engine orchestration pattern is implemented in a router, which\ndispatches original OpenAI-style completion requests to a chain of\nmicroserving API calls. In this code example, we create a subclass of\nRouter (which includes wrappers for calling microserving APIs), and\noverride ``translate_request`` function. The ``translate_request``\nfunction takes in a request and a unique identifier of the request\n(``request_id``), and returns an AsyncGenerator of response. We launch\nthe CustomRouter and 2 engines, each of which has tensor parallel degree\n2. Engine 0 is prefill engine and engine 1 is decode engine.\n\n.. 
code:: python\n\n   from mlc_llm.router import Router\n   from mlc_llm.protocol import openai_api_protocol\n   from typing import Any, AsyncGenerator\n   from mlc_llm.serve.entrypoints import microserving_entrypoints\n   from mlc_llm.interface.router import serve\n\n   import aiohttp\n\n   class CustomRouter(Router):\n       async def translate_request(self, request: openai_api_protocol.CompletionRequest, request_id: str) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n           pass  # filled in step by step in the rest of this tutorial\n\n\n   serve(\n       model=\"/path/to/model\", # replace this with actual path\n       model_lib=\"/path/to/model_lib\", # replace this with actual path\n       router_host=\"127.0.0.1\",\n       router_port=9123,\n       endpoint_hosts=[\"127.0.0.1\", \"127.0.0.1\"],\n       endpoint_ports=[9124, 9125],\n       endpoint_num_gpus=[2, 2],\n       enable_prefix_cache=False,\n       router_type=CustomRouter,\n   )\n\nIn the ``translate_request`` function, we first assign ``request_id`` to\n``request.user``; the request id is later passed as an argument to\nthe microserving API.\n\n.. code:: python\n\n   # we will pass request_id as an argument in microserving API calls\n   request.user = request_id\n\n\nNext, call ``prep_recv`` on the decode engine to prepare KV entries for\nreceiving from the remote engine. ``end=decode_start`` (i.e., ``len(request.prompt) - 1``) means that we will let the prefill\nengine prefill all except the last token, which ensures that the\nprefill engine does not need sampling logic. ``prep_recv`` returns the\naddress to receive KV from the remote engine and the matched prefix length. For\nsimplicity, we do not enable prefix cache in this tutorial, so we only\nneed the KV address here.\n\n.. code:: python\n\n   async with aiohttp.ClientSession(\n       timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True\n   ) as session:\n       decode_start = len(request.prompt) - 1\n       # 1. 
Ask decode engine to prepare KV entries to receive from prefill engine\n       prep_recv_request = microserving_entrypoints.PrepRecvRequest(\n           **request.model_dump(), end=decode_start\n       )\n       (\n           kv_addr_info,\n           _,\n       ) = await self.send_prepare_receive(\n           session=session,\n           request=prep_recv_request,\n           server_url=self.server_urls[1], # engine 0 is prefill, engine 1 is decode. Here is decode engine\n       )\n\nThen, call ``remote_send`` on the prefill engine to compute and send KV\nto the decode engine. ``recv_rank=self.device_id_starts[1]`` means that we\nare sending KV to engine 1 (the decode engine).\n\n.. code:: python\n\n   # 2. Ask prefill engine to send KV to decode engine\n   remote_send_request = microserving_entrypoints.RemoteSendRequest(\n       **request.model_dump(),\n       begin=0,\n       end=decode_start,\n       kv_addr_info=kv_addr_info,\n       recv_rank=self.device_id_starts[1], # the rank of decode engine\n   )\n   await self.send_remote_send(\n       session=session,\n       request=remote_send_request,\n       server_url=self.server_urls[0], # prefill engine\n   )\n\nFinally, call ``start_generate`` on the decode engine to start\ngenerating tokens. ``begin=decode_start`` means we will prefill the last\ntoken in the prompt and start decoding. Notably, the decode process of\nthe request may be preempted. In such a case, we yield ``None`` so that the\nrouter will rerun the ``translate_request`` function.\n\n.. code:: python\n\n   # 3. 
Start decoding\n   start_generate_request = microserving_entrypoints.StartGenerateRequest(\n       **request.model_dump(),\n       begin=decode_start,\n   )\n   async for response in self.send_start_generate(\n       session=session,\n       request=start_generate_request,\n       server_url=self.server_urls[1],\n   ):\n       if len(response.choices) > 0:\n           finish_reason = response.choices[0].finish_reason\n           if finish_reason == \"preempt\":\n               yield None\n       yield response\n\nBringing everything together, the complete code is as follows:\n\n.. code:: python\n\n   from mlc_llm.router import Router\n   from mlc_llm.protocol import openai_api_protocol\n   from typing import Any, AsyncGenerator\n   from mlc_llm.serve.entrypoints import microserving_entrypoints\n   from mlc_llm.interface.router import serve\n\n   import aiohttp\n\n\n   class CustomRouter(Router):\n       async def translate_request(self, request: openai_api_protocol.CompletionRequest, request_id: str) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n           # we will pass request_id as an argument in microserving API calls\n           request.user = request_id\n\n           async with aiohttp.ClientSession(\n               timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True\n           ) as session:\n               decode_start = len(request.prompt) - 1\n               # 1. Ask decode engine to prepare KV entries to receive from prefill engine\n               prep_recv_request = microserving_entrypoints.PrepRecvRequest(\n                   **request.model_dump(), end=decode_start\n               )\n               (\n                   kv_addr_info,\n                   _,\n               ) = await self.send_prepare_receive(\n                   session=session,\n                   request=prep_recv_request,\n                   server_url=self.server_urls[1], # engine 0 is prefill, engine 1 is decode. 
Here is decode engine\n               )\n               # 2. Ask prefill engine to send KV to decode engine\n               remote_send_request = microserving_entrypoints.RemoteSendRequest(\n                   **request.model_dump(),\n                   begin=0,\n                   end=decode_start,\n                   kv_addr_info=kv_addr_info,\n                   recv_rank=self.device_id_starts[1], # the rank of decode engine\n               )\n               await self.send_remote_send(\n                   session=session,\n                   request=remote_send_request,\n                   server_url=self.server_urls[0], # prefill engine\n               )\n               # 3. Start decoding\n               start_generate_request = microserving_entrypoints.StartGenerateRequest(\n                   **request.model_dump(),\n                   begin=decode_start,\n               )\n               async for response in self.send_start_generate(\n                   session=session,\n                   request=start_generate_request,\n                   server_url=self.server_urls[1],\n               ):\n                   if len(response.choices) > 0:\n                       finish_reason = response.choices[0].finish_reason\n                       if finish_reason == \"preempt\":\n                           yield None\n                   yield response\n\n\n   serve(\n       model=\"/path/to/model\", # replace this with actual path\n       model_lib=\"/path/to/model_lib\", # replace this with actual path\n       router_host=\"127.0.0.1\",\n       router_port=9123,\n       endpoint_hosts=[\"127.0.0.1\", \"127.0.0.1\"],\n       endpoint_ports=[9124, 9125],\n       endpoint_num_gpus=[2, 2],\n       enable_prefix_cache=False,\n       router_type=CustomRouter,\n   )\n"
  },
  {
    "path": "docs/privacy.rst",
    "content": "MLC Chat App Privacy\n====================\n\nMLC Chat run all generation locally.\nAll data stays in users' device and is not collected by the app.\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "--find-links https://mlc.ai/wheels\nfastapi\nml_dtypes>=0.5.1\nmlc-ai-nightly-cpu\nopenai\nprompt_toolkit\npydantic\nsafetensors\nshortuuid\nsphinx == 5.2.3\nsphinx-reredirects==0.1.2\nsphinx-rtd-theme\nsphinx-tabs == 3.4.1\nsphinx-toolbox == 3.4.0\nsphinxcontrib-napoleon==0.7\nsphinxcontrib_httpdomain==1.8.1\ntiktoken\ntlcpack-sphinx-addon==0.2.2\ntorch\nuvicorn\n"
  },
  {
    "path": "examples/python/microserving/custom_router.py",
    "content": "\"\"\"Microserving customized router example.\"\"\"\n\nfrom typing import Any, AsyncGenerator\n\nimport aiohttp  # pylint: disable=import-error\n\nfrom mlc_llm.interface.router import serve\nfrom mlc_llm.protocol import openai_api_protocol\nfrom mlc_llm.router import Router\nfrom mlc_llm.serve.entrypoints import microserving_entrypoints\n\n\nclass CustomRouter(Router):\n    \"\"\"A customized router class in Microserving.\"\"\"\n\n    async def translate_request(\n        self, request: openai_api_protocol.CompletionRequest, request_id: str\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        # we will pass request_id as an argument in microserving API calls\n        request.user = request_id\n\n        async with aiohttp.ClientSession(\n            timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True\n        ) as session:\n            decode_start = len(request.prompt) - 1\n            # 1. Ask decode engine to prepare KV entries to receive from prefill engine\n            prep_recv_request = microserving_entrypoints.PrepRecvRequest(\n                **request.model_dump(), end=decode_start\n            )\n            (\n                kv_addr_info,\n                _,\n            ) = await self.send_prepare_receive(\n                session=session,\n                request=prep_recv_request,\n                server_url=self.server_urls[\n                    1\n                ],  # engine 0 is prefill, engine 1 is decode. Here is decode engine\n            )\n            # 2. 
Ask prefill engine to send KV to decode engine\n            remote_send_request = microserving_entrypoints.RemoteSendRequest(\n                **request.model_dump(),\n                begin=0,\n                end=decode_start,\n                kv_addr_info=kv_addr_info,\n                recv_rank=self.device_id_starts[1],  # the rank of decode engine\n            )\n            await self.send_remote_send(\n                session=session,\n                request=remote_send_request,\n                server_url=self.server_urls[0],  # prefill engine\n            )\n            # 3. Start decoding\n            start_generate_request = microserving_entrypoints.StartGenerateRequest(\n                **request.model_dump(),\n                begin=decode_start,\n            )\n            async for response in self.send_start_generate(\n                session=session,\n                request=start_generate_request,\n                server_url=self.server_urls[1],\n            ):\n                if len(response.choices) > 0:\n                    finish_reason = response.choices[0].finish_reason\n                    if finish_reason == \"preempt\":\n                        yield None\n                yield response\n\n\nserve(\n    model=\"/path/to/model\",  # replace this with actual path\n    model_lib=\"/path/to/model_lib.so\",  # replace this with actual path\n    router_host=\"127.0.0.1\",\n    router_port=9123,\n    endpoint_hosts=[\"127.0.0.1\", \"127.0.0.1\"],\n    endpoint_ports=[9124, 9125],\n    endpoint_num_gpus=[2, 2],\n    enable_prefix_cache=False,\n    router_type=CustomRouter,\n)\n"
  },
  {
    "path": "examples/python/sample_mlc_engine.py",
    "content": "\"\"\"MLC Engine Python example.\"\"\"\n\nfrom mlc_llm import MLCEngine\n\n# Create engine\nmodel = \"HF://mlc-ai/Llama-3-8B-Instruct-q4f16_1-MLC\"  # pylint: disable=invalid-name\nengine = MLCEngine(model)\n\n# Run chat completion in OpenAI API.\nfor response in engine.chat.completions.create(\n    messages=[{\"role\": \"user\", \"content\": \"What is the meaning of life?\"}],\n    model=model,\n    stream=True,\n):\n    for choice in response.choices:\n        print(choice.delta.content, end=\"\", flush=True)\nprint(\"\\n\")\n\nengine.terminate()\n"
  },
  {
    "path": "examples/rest/nodejs/README.MD",
    "content": "# Node/Javascript/Typescript Access Examples for mlc_llm REST APIs\n\nPlease make sure you are running v18.17.x of node (and npm v9.6.7)  --  v20.x currently has some compatibility problems with typescript used in the langchain example.\n\nFirst install dependencies.\n\n`npm i`\n\nCopy `dotenv.exmaple` to `.env`.\n\nTo run JS chat completion (both streaming and non-streaming) example:\n\n`node sample_client.js`\n\nTo run OpenAI (chat completion streaming and non-streaming, and legacy completion) example:\n\n`node sample_openai.js`\n\nTo run LangchainJS Typescript example:\n\n`npm run example`\n"
  },
  {
    "path": "examples/rest/nodejs/dotenv.example",
    "content": "OPENAI_API_KEY=\"none\"\nOPENAI_API_BASE=\"http://127.0.0.1:8000/v1\"\n"
  },
  {
    "path": "examples/rest/nodejs/package.json",
    "content": "{\n  \"name\": \"mlc-llm-js-examples\",\n  \"version\": \"1.0.0\",\n  \"description\": \"\",\n  \"main\": \"index.js\",\n  \"type\": \"module\",\n  \"license\": \"AGPL-version-3.0\",\n  \"private\": false,\n  \"engines\": {\n    \"node\": \">= 14.0.0\",\n    \"npm\": \">= 6.0.0\"\n  },\n  \"homepage\": \"\",\n  \"repository\": {\n    \"type\": \"git\",\n    \"url\": \"\"\n  },\n  \"bugs\": \"\",\n  \"keywords\": [],\n  \"author\": {\n    \"name\": \"\",\n    \"email\": \"\",\n    \"url\": \"\"\n  },\n  \"contributors\": [],\n  \"scripts\": {\n    \"example\": \"ts-node --esm ./sample_langchain.ts\"\n  },\n  \"dependencies\": {\n    \"@types/node\": \"^20.4.4\",\n    \"dotenv\": \"^16.3.1\",\n    \"langchain\": \"^0.0.117\",\n    \"needle\": \"^3.2.0\",\n    \"openai\": \"^3.3.0\",\n    \"typescript\": \"^5.1.6\"\n  },\n  \"devDependencies\": {\n    \"ts-node\": \"^10.9.1\"\n  }\n}\n"
  },
  {
    "path": "examples/rest/nodejs/sample_client.js",
    "content": "import request from 'needle';\n\n( async () => {\nconst color = {\n    PURPLE : '\\x1b[95m',\n    CYAN : '\\x1b[96m',\n    DARKCYAN : '\\x1b[36m',\n    BLUE : '\\x1b[94m',\n    GREEN : '\\x1b[92m',\n    YELLOW : '\\x1b[93m',\n    RED : '\\x1b[91m',\n    BOLD : '\\x1b[1m',\n    UNDERLINE : '\\x1b[4m',\n    END : '\\x1b[0m'\n};\n\nlet payload = {\n    model : 'vicuna-v1-7b',\n    messages: [{\"role\": \"user\", \"content\": \"Write a haiku\"}],\n    stream: false\n};\n\nconst print = ( str ) => {\n    process.stdout.write(str);\n};\n\nconst newline = () => {\n    print('\\n');\n}\n\nnewline();\nprint(color.BOLD + \"Without streaming:\" + color.END);\nnewline();\n\nlet r = await request(\"post\", \"http://127.0.0.1:8000/v1/chat/completions\", payload, {json: true});\n\nprint(color.GREEN + r.body.choices[0].message.content + color.END);\nnewline();\n// Reset the chat (same reset endpoint as the Python sample client)\nr = await request(\"post\", \"http://127.0.0.1:8000/chat/reset\", payload, {json: true});\nprint(color.BOLD + \"Reset chat\" + color.END);\nnewline();\n\n// Get a response using a prompt with streaming\n\npayload = {\n    \"model\": \"vicuna-v1-7b\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Write a haiku\"}],\n    \"stream\": true\n}\n\nprint( color.BOLD + \"With streaming:\" + color.END);\nnewline();\nr =  request.post( \"http://127.0.0.1:8000/v1/chat/completions\", payload, {json: true})\n.on('readable', function() {\n    let jsData = '';\n    let data = '';\n    while (data = this.read()) {\n       // Strip the \"data: \" SSE prefix\n       const chunk = data.toString().substring(6);\n       if (chunk.trim() === \"[DONE]\")  break;\n       jsData = JSON.parse(chunk);\n       // Some deltas (e.g. the final one) carry no content\n       print(color.GREEN + (jsData.choices[0].delta.content ?? '') + color.END);\n    }\n})\n.on('done', async function () {\n    newline();\n    let txtresp = await request(\"get\", \"http://127.0.0.1:8000/stats\");\n    print(color.BOLD + \"Runtime stats:\" + color.END + txtresp.body);\n\n})\n\n})()\n"
  },
  {
    "path": "examples/rest/nodejs/sample_langchain.ts",
    "content": "import { OpenAI } from \"langchain/llms/openai\";\nimport { BufferWindowMemory } from \"langchain/memory\";\nimport { LLMChain } from \"langchain/chains\";\nimport { PromptTemplate } from \"langchain/prompts\";\nimport {TextLoader } from \"langchain/document_loaders/fs/text\";\nimport { loadQAStuffChain } from \"langchain/chains\";\n\nconst color = {\n    PURPLE : '\\x1b[95m',\n    CYAN : '\\x1b[96m',\n    DARKCYAN : '\\x1b[36m',\n    BLUE : '\\x1b[94m',\n    GREEN : '\\x1b[92m',\n    YELLOW : '\\x1b[93m',\n    RED : '\\x1b[91m',\n    BOLD : '\\x1b[1m',\n    UNDERLINE : '\\x1b[4m',\n    END : '\\x1b[0m'\n};\n\nfunction print(str: string) {\n    process.stdout.write(str);\n}\n\nconst newline = () => {\n    print('\\n');\n}\n\n  const chat = new OpenAI( {\n      openAIApiKey: \"empty\",\n      temperature: 0\n    },   {\n        basePath: 'http://127.0.0.1:8000/v1'\n    });\n\n// Conversational LLMChain example\n  const memory = new BufferWindowMemory({ memoryKey: \"history\", k: 1 });\n\n  const template = `The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. 
If the AI does not know the answer to a question, it truthfully says it does not know.\n\n    Current conversation:\n    {history}\n    Human: {human_input}\n    AI:`;\n\n\n  const prompt = PromptTemplate.fromTemplate(template);\n  let chain = new LLMChain({ llm: chat, prompt, memory });\n\n  let input = \"Write a poem about Pittsburgh.\";\n  print(color.BOLD + input + \"...\" + color.END);\n  newline();\n  let res = await chain.call({ human_input:  input });\n  newline();\n  print(color.GREEN + res.text + color.END);\n  newline();\n  input = \"What does it mean?\";\n  print(color.BOLD + input + \"...\" + color.END);\n  newline();\n  res = await chain.call({ human_input: input });\n  newline();\n  print(color.GREEN + res.text + color.END);\n  newline();\n\n// Question and answer stuff chain example with text loader\nconst loader = new TextLoader('../resources/linux.txt');\nconst documents = await loader.load();\nconst schain =  loadQAStuffChain(chat);\nconst query = \"When was Linux released?\";\nnewline(); newline();\nprint(color.BOLD + \"Query: \" + color.END + color.BLUE + query + color.END);\nnewline();\nconst result = await schain.call({ input_documents: documents,  question: query});\nprint(color.BOLD + \"Response: \" + color.END +  color.GREEN + result.text  + color.END);\n"
  },
  {
    "path": "examples/rest/nodejs/sample_openai.js",
    "content": "import { Configuration, OpenAIApi }  from \"openai\";\nimport dotenv from \"dotenv\";\ndotenv.config();\n\n( async () =>  {\n\nconst configuration = new Configuration({\n    apiKey: process.env.OPENAI_API_KEY,\n    basePath : process.env.OPENAI_API_BASE\n})\nconst openai = new OpenAIApi(configuration);\nlet model = \"vicuna-v1-7b\"\n\nconst color = {\n    PURPLE : '\\x1b[95m',\n    CYAN : '\\x1b[96m',\n    DARKCYAN : '\\x1b[36m',\n    BLUE : '\\x1b[94m',\n    GREEN : '\\x1b[92m',\n    YELLOW : '\\x1b[93m',\n    RED : '\\x1b[91m',\n    BOLD : '\\x1b[1m',\n    UNDERLINE : '\\x1b[4m',\n    END : '\\x1b[0m'\n};\n\nconst print = ( str ) => {\n    process.stdout.write(str);\n};\n\nconst newline = () => {\n    print('\\n');\n}\n\n// Chat completion example without streaming\nnewline();\nprint(color.BOLD + \"OpenAI chat completion example without streaming:\" + color.END);\nnewline();\n\nlet completion = await openai.createChatCompletion({\n  model: model,\n  messages: [{\"role\": \"user\", \"content\": \"Write a poem about OpenAI\"}]\n});\n\n\nprint(color.GREEN + completion.data.choices[0].message.content + color.END)\nnewline();  newline();\n\n\n// Chat completion example with streaming\n// (raw SSE parsing, since the v3 openai npm module does not support streaming; support is expected in 4.x)\n\nprint(color.BOLD + \"OpenAI chat completion example with streaming:\" + color.END);\nnewline();\ncompletion = await openai.createChatCompletion({\n    model: model,\n    messages: [{\"role\": \"user\", \"content\": \"Write a poem about OpenAI\"}],\n    stream: true,\n}, {responseType: 'stream'});\n\ncompletion.data.on('data', (data) => {\n        // Strip the \"data: \" SSE prefix\n        const chunk = data.toString().substring(6).trim();\n        // The end-of-stream sentinel is not JSON, so skip it\n        if (chunk === \"[DONE]\") return;\n        const parsed = JSON.parse(chunk);\n        print(color.GREEN + (parsed.choices[0].delta.content ?? '') + color.END);\n});\n\ncompletion.data.on('close', async ()  => {\n    newline(); newline();\n\n    // Completion example\n    print(color.BOLD + \"OpenAI completion example:\" + color.END)\n    
newline();\n    let res = await openai.createCompletion({ prompt: \"Write a poem about OpenAI\", model: model});\n    print(color.GREEN + res.data.choices[0].text + color.END);\n    newline();  newline();\n\n    });\n})()\n"
  },
  {
    "path": "examples/rest/nodejs/tsconfig.json",
    "content": "{\n  \"compilerOptions\": {\n    \"target\": \"es2020\",                                  /* Set the JavaScript language version for emitted JavaScript and include compatible library declarations. */\n    \"lib\": [\"es2020\"],                                   /* Specify a set of bundled library declaration files that describe the target runtime environment. */\n    \"module\": \"nodenext\",                                /* Specify what module code is generated. */\n    \"rootDir\": \"src\",                                    /* Specify the root folder within your source files. */\n    \"outDir\": \"./dist\",                                  /* Specify an output folder for all emitted files. */\n    \"esModuleInterop\": true,                             /* Emit additional JavaScript to ease support for importing CommonJS modules. This enables 'allowSyntheticDefaultImports' for type compatibility. */\n    \"forceConsistentCasingInFileNames\": true,            /* Ensure that casing is correct in imports. */\n    \"strict\": true,                                      /* Enable all strict type-checking options. */\n    \"noImplicitAny\": true,                               /* Enable error reporting for expressions and declarations with an implied 'any' type. */\n    \"skipLibCheck\": true                                 /* Skip type checking all .d.ts files. */\n  }\n}\n"
  },
  {
    "path": "examples/rest/python/sample_client.py",
    "content": "import json\n\nimport requests\n\n\nclass color:\n    PURPLE = \"\\033[95m\"\n    CYAN = \"\\033[96m\"\n    DARKCYAN = \"\\033[36m\"\n    BLUE = \"\\033[94m\"\n    GREEN = \"\\033[92m\"\n    YELLOW = \"\\033[93m\"\n    RED = \"\\033[91m\"\n    BOLD = \"\\033[1m\"\n    UNDERLINE = \"\\033[4m\"\n    END = \"\\033[0m\"\n\n\n# Get a response using a prompt without streaming\npayload = {\n    \"model\": \"vicuna-v1-7b\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Write a haiku\"}],\n    \"stream\": False,\n}\nr = requests.post(\"http://127.0.0.1:8000/v1/chat/completions\", json=payload)\nprint(\n    f\"{color.BOLD}Without streaming:{color.END}\\n{color.GREEN}{r.json()['choices'][0]['message']['content']}{color.END}\\n\"\n)\n\n# Reset the chat\nr = requests.post(\"http://127.0.0.1:8000/chat/reset\", json=payload)\nprint(f\"{color.BOLD}Reset chat:{color.END} {str(r)}\\n\")\n\n# Get a response using a prompt with streaming\npayload = {\n    \"model\": \"vicuna-v1-7b\",\n    \"messages\": [{\"role\": \"user\", \"content\": \"Write a haiku\"}],\n    \"stream\": True,\n}\nwith requests.post(\"http://127.0.0.1:8000/v1/chat/completions\", json=payload, stream=True) as r:\n    print(f\"{color.BOLD}With streaming:{color.END}\")\n    # Iterate over SSE lines; iterating the response directly yields fixed-size\n    # byte chunks that can split an event in the middle.\n    for line in r.iter_lines():\n        if not line:  # skip the blank lines that separate SSE events\n            continue\n        data = line[6:].decode(\"utf-8\")  # strip the \"data: \" prefix\n        if data.strip() == \"[DONE]\":\n            break\n        content = json.loads(data)[\"choices\"][0][\"delta\"].get(\"content\", \"\")\n        print(f\"{color.GREEN}{content}{color.END}\", end=\"\", flush=True)\n    print(\"\\n\")\n\n# Get the latest runtime stats\nr = requests.get(\"http://127.0.0.1:8000/stats\")\nprint(f\"{color.BOLD}Runtime stats:{color.END} {r.json()}\\n\")\n"
  },
  {
    "path": "examples/rest/python/sample_langchain.py",
    "content": "from langchain import LLMChain, PromptTemplate\nfrom langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\nfrom langchain.chains import RetrievalQA\nfrom langchain.chains.question_answering import load_qa_chain\nfrom langchain.chat_models import ChatOpenAI\nfrom langchain.document_loaders import (\n    DirectoryLoader,\n    TextLoader,\n    UnstructuredRSTLoader,\n)\nfrom langchain.llms import OpenAI\nfrom langchain.memory import ConversationBufferWindowMemory\nfrom langchain.text_splitter import CharacterTextSplitter\nfrom langchain.vectorstores import Chroma\n\n# Note that Langchain support for embedding documents using MLC is currently blocked on\n# https://github.com/langchain-ai/langchain/pull/7815\n# We have subclassed `OpenAIEmbeddings` in the meantime to get around this dependency.\nfrom mlc_llm.contrib.embeddings.openai import MLCEmbeddings\n\n# First set the following in your environment:\n# export OPENAI_API_BASE=http://127.0.0.1:8000/v1\n# export OPENAI_API_KEY=EMPTY\n\n# Note that Langchain does not currently support Pydantic v2:\n# https://github.com/langchain-ai/langchain/issues/6841\n# Please ensure that your `pydantic` version is < 2.0\n\n\nclass color:\n    PURPLE = \"\\033[95m\"\n    CYAN = \"\\033[96m\"\n    DARKCYAN = \"\\033[36m\"\n    BLUE = \"\\033[94m\"\n    GREEN = \"\\033[92m\"\n    YELLOW = \"\\033[93m\"\n    RED = \"\\033[91m\"\n    BOLD = \"\\033[1m\"\n    UNDERLINE = \"\\033[4m\"\n    END = \"\\033[0m\"\n\n\ndef llm_chain_example():\n    template = \"\"\"\n    {history}\n    USER: {human_input}\n    ASSISTANT:\"\"\"\n\n    prompt = PromptTemplate(input_variables=[\"history\", \"human_input\"], template=template)\n\n    llm_chain = LLMChain(\n        llm=ChatOpenAI(streaming=True, callbacks=[StreamingStdOutCallbackHandler()]),\n        prompt=prompt,\n        verbose=True,\n        memory=ConversationBufferWindowMemory(human_prefix=\"USER\", ai_prefix=\"ASSISTANT\"),\n    )\n\n    output = 
llm_chain.predict(human_input=\"Write a short poem about Pittsburgh.\")\n    output = llm_chain.predict(human_input=\"What does the poem mean?\")\n\n\ndef load_qa_chain_example():\n    loader = TextLoader(\"../resources/linux.txt\")\n    documents = loader.load()\n    chain = load_qa_chain(llm=OpenAI(), chain_type=\"stuff\", verbose=False)\n    query = \"When was Linux released?\"\n    print(f\"{color.BOLD}Query:{color.END} {color.BLUE} {query}{color.END}\")\n    print(\n        f\"{color.BOLD}Response:{color.END} {color.GREEN}{chain.run(input_documents=documents, question=query)}{color.END}\"\n    )\n\n\ndef retrieval_qa_sotu_example():\n    prompt_template = \"\"\"Use only the following pieces of context to answer the question at the end. Don't use any other knowledge.\n\n    {context}\n\n    USER: {question}\n    ASSISTANT:\"\"\"\n\n    PROMPT = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n\n    loader = TextLoader(\"../resources/state_of_the_union.txt\")\n    documents = loader.load()\n\n    text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n    texts = text_splitter.split_documents(documents)\n    # print(texts)\n    embeddings = MLCEmbeddings(deployment=\"text-embedding-ada-002\", embedding_ctx_length=None)\n    db = Chroma.from_documents(documents=texts, embedding=embeddings)\n    retriever = db.as_retriever(search_type=\"similarity\", search_kwargs={\"k\": 2})\n    qa = RetrievalQA.from_chain_type(\n        llm=OpenAI(),\n        chain_type=\"stuff\",\n        retriever=retriever,\n        return_source_documents=True,\n        chain_type_kwargs={\"prompt\": PROMPT},\n    )\n    questions = [\n        \"What is the American Rescue Plan?\",\n        \"What did the president say about Ketanji Brown Jackson?\",\n        \"Who is mentioned in the speech?\",\n        \"To whom is the speech addressed?\",\n        \"Tell me more about the Made in America campaign.\",\n    ]\n\n    for qn in 
questions:\n        print(f\"{color.BOLD}QUESTION:{color.END} {qn}\")\n        res = qa({\"query\": qn})\n        print(f\"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}\")\n        print(\n            f\"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}\"\n        )\n        print()\n\n\ndef retrieval_qa_mlc_docs_example():\n    prompt_template = \"\"\"Use only the following pieces of context to answer the question at the end. Don't use any other knowledge.\n\n    {context}\n\n    USER: {question}\n    ASSISTANT:\"\"\"\n\n    PROMPT = PromptTemplate(template=prompt_template, input_variables=[\"context\", \"question\"])\n\n    loader = DirectoryLoader(\n        \"../../../docs\",\n        glob=\"*/*.rst\",\n        show_progress=True,\n        loader_cls=UnstructuredRSTLoader,\n        loader_kwargs={\"mode\": \"single\"},\n    )\n    documents = loader.load()\n    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)\n    texts = text_splitter.split_documents(documents)\n    embeddings = MLCEmbeddings(deployment=\"text-embedding-ada-002\", embedding_ctx_length=None)\n    db = Chroma.from_documents(collection_name=\"abc\", documents=texts, embedding=embeddings)\n    retriever = db.as_retriever(search_type=\"similarity\", search_kwargs={\"k\": 3})\n    qa = RetrievalQA.from_chain_type(\n        llm=OpenAI(),\n        chain_type=\"stuff\",\n        retriever=retriever,\n        return_source_documents=True,\n        chain_type_kwargs={\"prompt\": PROMPT},\n    )\n    while True:\n        qn = input(f\"{color.BOLD}QUESTION:{color.END} \")\n        res = qa({\"query\": qn})\n        print(f\"{color.BOLD}RESPONSE:{color.END} {color.GREEN}{res['result']}{color.END}\")\n        print(\n            f\"{color.BOLD}SOURCE:{color.END} {color.BLUE}{repr(res['source_documents'][0].page_content)}{color.END}\"\n        )\n        print()\n\n    # Some example questions:\n    # - What 
is the chat config?\n    # - What is temperature?\n    # - What are the REST API endpoints?\n    # - What are the available quantization options?\n\n\n# Uncomment one of the following lines to try out the corresponding demo:\n\n# llm_chain_example()\n# load_qa_chain_example()\n# retrieval_qa_sotu_example()\n# retrieval_qa_mlc_docs_example()\n"
  },
  {
    "path": "examples/rest/python/sample_openai.py",
    "content": "import openai\n\nopenai.api_key = \"None\"\nopenai.api_base = \"http://127.0.0.1:8000/v1\"\n\nmodel = \"vicuna-v1-7b\"\n\n\nclass color:\n    PURPLE = \"\\033[95m\"\n    CYAN = \"\\033[96m\"\n    DARKCYAN = \"\\033[36m\"\n    BLUE = \"\\033[94m\"\n    GREEN = \"\\033[92m\"\n    YELLOW = \"\\033[93m\"\n    RED = \"\\033[91m\"\n    BOLD = \"\\033[1m\"\n    UNDERLINE = \"\\033[4m\"\n    END = \"\\033[0m\"\n\n\n# Chat completion example without streaming\nprint(f\"{color.BOLD}OpenAI chat completion example without streaming:{color.END}\\n\")\ncompletion = openai.ChatCompletion.create(\n    model=model, messages=[{\"role\": \"user\", \"content\": \"Write a poem about OpenAI\"}]\n)\nprint(f\"{color.GREEN}{completion.choices[0].message.content}{color.END}\\n\\n\")\n\n# Chat completion example with streaming\nprint(f\"{color.BOLD}OpenAI chat completion example with streaming:{color.END}\\n\")\nres = openai.ChatCompletion.create(\n    model=model, messages=[{\"role\": \"user\", \"content\": \"Write a poem about OpenAI\"}], stream=True\n)\nfor chunk in res:\n    content = chunk[\"choices\"][0][\"delta\"].get(\"content\", \"\")\n    print(f\"{color.GREEN}{content}{color.END}\", end=\"\", flush=True)\nprint(\"\\n\")\n\n# Completion example\nprint(f\"{color.BOLD}OpenAI completion example:{color.END}\\n\")\nres = openai.Completion.create(prompt=\"Write a poem about OpenAI\", model=model)\nprint(f\"{color.GREEN}{res.choices[0].text}{color.END}\\n\\n\")\n"
  },
  {
    "path": "examples/rest/resources/linux.txt",
    "content": "Linux is a family of open-source Unix-like operating systems based on the Linux kernel, an operating system kernel first released on September 17, 1991, by Linus Torvalds. Linux is typically packaged as a Linux distribution, which includes the kernel and supporting system software and libraries, many of which are provided by the GNU Project. Many Linux distributions use the word \"Linux\" in their name, but the Free Software Foundation uses the name \"GNU/Linux\" to emphasize the importance of GNU software, causing some controversy.\n\nPopular Linux distributions include Debian, Fedora Linux, and Ubuntu, the latter of which itself consists of many different distributions and modifications, including Lubuntu and Xubuntu. Commercial distributions include Red Hat Enterprise Linux and SUSE Linux Enterprise. Desktop Linux distributions include a windowing system such as X11 or Wayland, and a desktop environment such as GNOME or KDE Plasma. Distributions intended for servers may omit graphics altogether, or include a solution stack such as LAMP. Because Linux is freely redistributable, anyone may create a distribution for any purpose.\n\nLinux was originally developed for personal computers based on the Intel x86 architecture, but has since been ported to more platforms than any other operating system. Because of the dominance of the Linux-based Android on smartphones, Linux, including Android, has the largest installed base of all general-purpose operating systems, as of May 2022. Although Linux is, as of November 2022, used by only around 2.6 percent of desktop computers, the Chromebook, which runs the Linux kernel-based ChromeOS, dominates the US K–12 education market and represents nearly 20 percent of sub-$300 notebook sales in the US. 
Linux is the leading operating system on servers (over 96.4% of the top 1 million web servers' operating systems are Linux), leads other big iron systems such as mainframe computers, and is used on all of the world's 500 fastest supercomputers (since November 2017, having gradually displaced all competitors).\n\nLinux also runs on embedded systems, i.e. devices whose operating system is typically built into the firmware and is highly tailored to the system. This includes routers, automation controls, smart home devices, video game consoles, televisions (Samsung and LG Smart TVs), automobiles (Tesla, Audi, Mercedes-Benz, Hyundai and Toyota), and spacecraft (Falcon 9 rocket, Dragon crew capsule and the Perseverance rover).\n\nLinux is one of the most prominent examples of free and open-source software collaboration. The source code may be used, modified and distributed commercially or non-commercially by anyone under the terms of its respective licenses, such as the GNU General Public License (GPL). The Linux kernel, for example, is licensed under the GPLv2, with an exception for system calls that allows code that calls the kernel via system calls not to be licensed under the GPL.\n\nThe Unix operating system was conceived and implemented in 1969, at AT&T's Bell Labs, in the United States by Ken Thompson, Dennis Ritchie, Douglas McIlroy, and Joe Ossanna. First released in 1971, Unix was written entirely in assembly language, as was common practice at the time. In 1973, in a key pioneering approach, it was rewritten in the C programming language by Dennis Ritchie (with the exception of some hardware and I/O routines). The availability of a high-level language implementation of Unix made its porting to different computer platforms easier.\n\nDue to an earlier antitrust case forbidding it from entering the computer business, AT&T licensed the operating system's source code as a trade secret to anyone who asked. 
As a result, Unix grew quickly and became widely adopted by academic institutions and businesses. In 1984, AT&T divested itself of its regional operating companies, and was released from its obligation not to enter the computer business; freed of that obligation, Bell Labs began selling Unix as a proprietary product, where users were not legally allowed to modify it.\n\nOnyx Systems began selling early microcomputer-based Unix workstations in 1980. Later, Sun Microsystems, founded as a spin-off of a student project at Stanford University, also began selling Unix-based desktop workstations in 1982. While Sun workstations did not utilize commodity PC hardware, for which Linux was later originally developed, it represented the first successful commercial attempt at distributing a primarily single-user microcomputer that ran a Unix operating system.\n\nWith Unix increasingly \"locked in\" as a proprietary product, the GNU Project, started in 1983 by Richard Stallman, had the goal of creating a \"complete Unix-compatible software system\" composed entirely of free software. Work began in 1984. Later, in 1985, Stallman started the Free Software Foundation and wrote the GNU General Public License (GNU GPL) in 1989. By the early 1990s, many of the programs required in an operating system (such as libraries, compilers, text editors, a command-line shell, and a windowing system) were completed, although low-level elements such as device drivers, daemons, and the kernel, called GNU Hurd, were stalled and incomplete.\n\nMINIX was created by Andrew S. Tanenbaum, a computer science professor, and released in 1987 as a minimal Unix-like operating system targeted at students and others who wanted to learn operating system principles. 
Although the complete source code of MINIX was freely available, the licensing terms prevented it from being free software until the licensing changed in April 2000.\n\nAlthough not released until 1992, due to legal complications, development of 386BSD, from which NetBSD, OpenBSD and FreeBSD descended, predated that of Linux.\n\nLinus Torvalds has stated on separate occasions that if the GNU kernel or 386BSD had been available at the time (1991), he probably would not have created Linux.\n"
  },
  {
    "path": "examples/rest/resources/state_of_the_union.txt",
    "content": "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans.\n\nLast year COVID-19 kept us apart. This year we are finally together again.\n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans.\n\nWith a duty to one another to the American people to the Constitution.\n\nAnd with an unwavering resolve that freedom will always triumph over tyranny.\n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated.\n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined.\n\nHe met the Ukrainian people.\n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world.\n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland.\n\nIn this struggle as President Zelenskyy said in his speech to the European Parliament “Light will win over darkness.” The Ukrainian Ambassador to the United States is here tonight.\n\nLet each of us here tonight in this Chamber send an unmistakable signal to Ukraine and to the world.\n\nPlease rise if you are able and show that, Yes, we the United States of America stand with the Ukrainian people.\n\nThroughout our history we’ve learned this lesson when dictators do not pay a price for their aggression they cause more chaos.\n\nThey keep moving.\n\nAnd the costs and the threats to America and the world keep rising.\n\nThat’s why the NATO Alliance was created to secure peace and stability in Europe after World War 2.\n\nThe United States is a member along with 29 other nations.\n\nIt matters. American diplomacy matters. 
American resolve matters.\n\nPutin’s latest attack on Ukraine was premeditated and unprovoked.\n\nHe rejected repeated efforts at diplomacy.\n\nHe thought the West and NATO wouldn’t respond. And he thought he could divide us at home. Putin was wrong. We were ready.  Here is what we did.\n\nWe prepared extensively and carefully.\n\nWe spent months building a coalition of other freedom-loving nations from Europe and the Americas to Asia and Africa to confront Putin.\n\nI spent countless hours unifying our European allies. We shared with the world in advance what we knew Putin was planning and precisely how he would try to falsely justify his aggression.\n\nWe countered Russia’s lies with truth.\n\nAnd now that he has acted the free world is holding him accountable.\n\nAlong with twenty-seven members of the European Union including France, Germany, Italy, as well as countries like the United Kingdom, Canada, Japan, Korea, Australia, New Zealand, and many others, even Switzerland.\n\nWe are inflicting pain on Russia and supporting the people of Ukraine. Putin is now isolated from the world more than ever.\n\nTogether with our allies –we are right now enforcing powerful economic sanctions.\n\nWe are cutting off Russia’s largest banks from the international financial system.\n\nPreventing Russia’s central bank from defending the Russian Ruble making Putin’s $630 Billion “war fund” worthless.\n\nWe are choking off Russia’s access to technology that will sap its economic strength and weaken its military for years to come.\n\nTonight I say to the Russian oligarchs and corrupt leaders who have bilked billions of dollars off this violent regime no more.\n\nThe U.S. Department of Justice is assembling a dedicated task force to go after the crimes of Russian oligarchs.\n\nWe are joining with our European allies to find and seize your yachts your luxury apartments your private jets. 
We are coming for your ill-begotten gains.\n\nAnd tonight I am announcing that we will join our allies in closing off American air space to all Russian flights – further isolating Russia – and adding an additional squeeze –on their economy. The Ruble has lost 30% of its value.\n\nThe Russian stock market has lost 40% of its value and trading remains suspended. Russia’s economy is reeling and Putin alone is to blame.\n\nTogether with our allies we are providing support to the Ukrainians in their fight for freedom. Military assistance. Economic assistance. Humanitarian assistance.\n\nWe are giving more than $1 Billion in direct assistance to Ukraine.\n\nAnd we will continue to aid the Ukrainian people as they defend their country and to help ease their suffering.\n\nLet me be clear, our forces are not engaged and will not engage in conflict with Russian forces in Ukraine.\n\nOur forces are not going to Europe to fight in Ukraine, but to defend our NATO Allies – in the event that Putin decides to keep moving west.\n\nFor that purpose we’ve mobilized American ground forces, air squadrons, and ship deployments to protect NATO countries including Poland, Romania, Latvia, Lithuania, and Estonia.\n\nAs I have made crystal clear the United States and our Allies will defend every inch of territory of NATO countries with the full force of our collective power.\n\nAnd we remain clear-eyed. The Ukrainians are fighting back with pure courage. But the next few days weeks, months, will be hard on them.\n\nPutin has unleashed violence and chaos.  But while he may make gains on the battlefield – he will pay a continuing high price over the long run.\n\nAnd a proud Ukrainian people, who have known 30 years  of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards.\n\nTo all Americans, I will be honest with you, as I’ve always promised. 
A Russian dictator, invading a foreign country, has costs around the world.\n\nAnd I’m taking robust action to make sure the pain of our sanctions  is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers.\n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world.\n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies.\n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming.\n\nBut I want you to know that we are going to be okay.\n\nWhen the history of this era is written Putin’s war on Ukraine will have left Russia weaker and the rest of the world stronger.\n\nWhile it shouldn’t have taken something so terrible for people around the world to see what’s at stake now everyone sees it clearly.\n\nWe see the unity among leaders of nations and a more unified Europe a more unified West. And we see unity among the people who are gathering in cities in large crowds around the world even in Russia to demonstrate their support for Ukraine.\n\nIn the battle between democracy and autocracy, democracies are rising to the moment, and the world is clearly choosing the side of peace and security.\n\nThis is a real test. It’s going to take time. So let us continue to draw inspiration from the iron will of the Ukrainian people.\n\nTo our fellow Ukrainian Americans who forge a deep bond that connects our two nations we stand with you.\n\nPutin may circle Kyiv with tanks, but he will never gain the hearts and souls of the Ukrainian people.\n\nHe will never extinguish their love of freedom. 
He will never weaken the resolve of the free world.\n\nWe meet tonight in an America that has lived through two of the hardest years this nation has ever faced.\n\nThe pandemic has been punishing.\n\nAnd so many families are living paycheck to paycheck, struggling to keep up with the rising cost of food, gas, housing, and so much more.\n\nI understand.\n\nI remember when my Dad had to leave our home in Scranton, Pennsylvania to find work. I grew up in a family where if the price of food went up, you felt it.\n\nThat’s why one of the first things I did as President was fight to pass the American Rescue Plan.\n\nBecause people were hurting. We needed to act, and we did.\n\nFew pieces of legislation have done more in a critical moment in our history to lift us out of crisis.\n\nIt fueled our efforts to vaccinate the nation and combat COVID-19. It delivered immediate economic relief for tens of millions of Americans.\n\nHelped put food on their table, keep a roof over their heads, and cut the cost of health insurance.\n\nAnd as my Dad used to say, it gave people a little breathing room.\n\nAnd unlike the $2 Trillion tax cut passed in the previous administration that benefitted the top 1% of Americans, the American Rescue Plan helped working people—and left no one behind.\n\nAnd it worked. It created jobs. 
Lots of jobs.\n\nIn fact—our economy created over 6.5 Million new jobs just last year, more jobs created in one year\nthan ever before in the history of America.\n\nOur economy grew at a rate of 5.7% last year, the strongest growth in nearly 40 years, the first step in bringing fundamental change to an economy that hasn’t worked for the working people of this nation for too long.\n\nFor the past 40 years we were told that if we gave tax breaks to those at the very top, the benefits would trickle down to everyone else.\n\nBut that trickle-down theory led to weaker economic growth, lower wages, bigger deficits, and the widest gap between those at the top and everyone else in nearly a century.\n\nVice President Harris and I ran for office with a new economic vision for America.\n\nInvest in America. Educate Americans. Grow the workforce. Build the economy from the bottom up\nand the middle out, not from the top down.\n\nBecause we know that when the middle class grows, the poor have a ladder up and the wealthy do very well.\n\nAmerica used to have the best roads, bridges, and airports on Earth.\n\nNow our infrastructure is ranked 13th in the world.\n\nWe won’t be able to compete for the jobs of the 21st Century if we don’t fix that.\n\nThat’s why it was so important to pass the Bipartisan Infrastructure Law—the most sweeping investment to rebuild America in history.\n\nThis was a bipartisan effort, and I want to thank the members of both parties who worked to make it happen.\n\nWe’re done talking about infrastructure weeks.\n\nWe’re going to have an infrastructure decade.\n\nIt is going to transform America and put us on a path to win the economic competition of the 21st Century that we face with the rest of the world—particularly with China.\n\nAs I’ve told Xi Jinping, it is never a good bet to bet against the American people.\n\nWe’ll create good jobs for millions of Americans, modernizing roads, airports, ports, and waterways all across America.\n\nAnd we’ll do it 
all to withstand the devastating effects of the climate crisis and promote environmental justice.\n\nWe’ll build a national network of 500,000 electric vehicle charging stations, begin to replace poisonous lead pipes—so every child—and every American—has clean water to drink at home and at school, provide affordable high-speed internet for every American—urban, suburban, rural, and tribal communities.\n\n4,000 projects have already been announced.\n\nAnd tonight, I’m announcing that this year we will start fixing over 65,000 miles of highway and 1,500 bridges in disrepair.\n\nWhen we use taxpayer dollars to rebuild America – we are going to Buy American: buy American products to support American jobs.\n\nThe federal government spends about $600 Billion a year to keep the country safe and secure.\n\nThere’s been a law on the books for almost a century\nto make sure taxpayers’ dollars support American jobs and businesses.\n\nEvery Administration says they’ll do it, but we are actually doing it.\n\nWe will buy American to make sure everything from the deck of an aircraft carrier to the steel on highway guardrails are made in America.\n\nBut to compete for the best jobs of the future, we also need to level the playing field with China and other competitors.\n\nThat’s why it is so important to pass the Bipartisan Innovation Act sitting in Congress that will make record investments in emerging technologies and American manufacturing.\n\nLet me give you one example of why it’s so important to pass it.\n\nIf you travel 20 miles east of Columbus, Ohio, you’ll find 1,000 empty acres of land.\n\nIt won’t look like much, but if you stop and look closely, you’ll see a “Field of dreams,” the ground on which America’s future will be built.\n\nThis is where Intel, the American company that helped build Silicon Valley, is going to build its $20 billion semiconductor “mega site”.\n\nUp to eight state-of-the-art factories in one place. 
10,000 new good-paying jobs.\n\nSome of the most sophisticated manufacturing in the world to make computer chips the size of a fingertip that power the world and our everyday lives.\n\nSmartphones. The Internet. Technology we have yet to invent.\n\nBut that’s just the beginning.\n\nIntel’s CEO, Pat Gelsinger, who is here tonight, told me they are ready to increase their investment from\n$20 billion to $100 billion.\n\nThat would be one of the biggest investments in manufacturing in American history.\n\nAnd all they’re waiting for is for you to pass this bill.\n\nSo let’s not wait any longer. Send it to my desk. I’ll sign it.\n\nAnd we will really take off.\n\nAnd Intel is not alone.\n\nThere’s something happening in America.\n\nJust look around and you’ll see an amazing story.\n\nThe rebirth of the pride that comes from stamping products “Made In America.” The revitalization of American manufacturing.\n\nCompanies are choosing to build new factories here, when just a few years ago, they would have built them overseas.\n\nThat’s what is happening. Ford is investing $11 billion to build electric vehicles, creating 11,000 jobs across the country.\n\nGM is making the largest investment in its history—$7 billion to build electric vehicles, creating 4,000 jobs in Michigan.\n\nAll told, we created 369,000 new manufacturing jobs in America just last year.\n\nPowered by people I’ve met like JoJo Burgess, from generations of union steelworkers from Pittsburgh, who’s here with us tonight.\n\nAs Ohio Senator Sherrod Brown says, “It’s time to bury the label “Rust Belt.”\n\nIt’s time.\n\nBut with all the bright spots in our economy, record job growth and higher wages, too many families are struggling to keep up with the bills.\n\nInflation is robbing them of the gains they might otherwise feel.\n\nI get it. 
That’s why my top priority is getting prices under control.\n\nLook, our economy roared back faster than most predicted, but the pandemic meant that businesses had a hard time hiring enough workers to keep up production in their factories.\n\nThe pandemic also disrupted global supply chains.\n\nWhen factories close, it takes longer to make goods and get them from the warehouse to the store, and prices go up.\n\nLook at cars.\n\nLast year, there weren’t enough semiconductors to make all the cars that people wanted to buy.\n\nAnd guess what, prices of automobiles went up.\n\nSo—we have a choice.\n\nOne way to fight inflation is to drive down wages and make Americans poorer.\n\nI have a better plan to fight inflation.\n\nLower your costs, not your wages.\n\nMake more cars and semiconductors in America.\n\nMore infrastructure and innovation in America.\n\nMore goods moving faster and cheaper in America.\n\nMore jobs where you can earn a good living in America.\n\nAnd instead of relying on foreign supply chains, let’s make it in America.\n\nEconomists call it “increasing the productive capacity of our economy.”\n\nI call it building a better America.\n\nMy plan to fight inflation will lower your costs and lower the deficit.\n\n17 Nobel laureates in economics say my plan will ease long-term inflationary pressures. Top business leaders and most Americans support my plan. And here’s the plan:\n\nFirst – cut the cost of prescription drugs. Just look at insulin. One in ten Americans has diabetes. In Virginia, I met a 13-year-old boy named Joshua Davis.\n\nHe and his Dad both have Type 1 diabetes, which means they need insulin every day. Insulin costs about $10 a vial to make.\n\nBut drug companies charge families like Joshua and his Dad up to 30 times more. 
I spoke with Joshua’s mom.\n\nImagine what it’s like to look at your child who needs insulin and have no idea how you’re going to pay for it.\n\nWhat it does to your dignity, your ability to look your child in the eye, to be the parent you expect to be.\n\nJoshua is here with us tonight. Yesterday was his birthday. Happy birthday, buddy.\n\nFor Joshua, and for the 200,000 other young people with Type 1 diabetes, let’s cap the cost of insulin at $35 a month so everyone can afford it.\n\nDrug companies will still do very well. And while we’re at it let Medicare negotiate lower prices for prescription drugs, like the VA already does.\n\nLook, the American Rescue Plan is helping millions of families on Affordable Care Act plans save $2,400 a year on their health care premiums. Let’s close the coverage gap and make those savings permanent.\n\nSecond – cut energy costs for families an average of $500 a year by combatting climate change.\n\nLet’s provide investments and tax credits to weatherize your homes and businesses to be energy efficient and you get a tax credit; double America’s clean energy production in solar, wind, and so much more;  lower the price of electric vehicles, saving you another $80 a month because you’ll never have to pay at the gas pump again.\n\nThird – cut the cost of child care. Many families pay up to $14,000 a year for child care per child.\n\nMiddle-class and working families shouldn’t have to pay more than 7% of their income for care of young children.\n\nMy plan will cut the cost in half for most families and help parents, including millions of women, who left the workforce during the pandemic because they couldn’t afford child care, to be able to get back to work.\n\nMy plan doesn’t stop there. It also includes home and long-term care. More affordable housing. And Pre-K for every 3- and 4-year-old.\n\nAll of these will lower costs.\n\nAnd under my plan, nobody earning less than $400,000 a year will pay an additional penny in new taxes. 
Nobody.\n\nThe one thing all Americans agree on is that the tax system is not fair. We have to fix it.\n\nI’m not looking to punish anyone. But let’s make sure corporations and the wealthiest Americans start paying their fair share.\n\nJust last year, 55 Fortune 500 corporations earned $40 billion in profits and paid zero dollars in federal income tax.\n\nThat’s simply not fair. That’s why I’ve proposed a 15% minimum tax rate for corporations.\n\nWe got more than 130 countries to agree on a global minimum tax rate so companies can’t get out of paying their taxes at home by shipping jobs and factories overseas.\n\nThat’s why I’ve proposed closing loopholes so the very wealthy don’t pay a lower tax rate than a teacher or a firefighter.\n\nSo that’s my plan. It will grow the economy and lower costs for families.\n\nSo what are we waiting for? Let’s get this done. And while you’re at it, confirm my nominees to the Federal Reserve, which plays a critical role in fighting inflation.\n\nMy plan will not only lower costs to give families a fair shot, it will lower the deficit.\n\nThe previous Administration not only ballooned the deficit with tax cuts for the very wealthy and corporations, it undermined the watchdogs whose job was to keep pandemic relief funds from being wasted.\n\nBut in my administration, the watchdogs have been welcomed back.\n\nWe’re going after the criminals who stole billions in relief money meant for small businesses and millions of Americans.\n\nAnd tonight, I’m announcing that the Justice Department will name a chief prosecutor for pandemic fraud.\n\nBy the end of this year, the deficit will be down to less than half what it was before I took office.\n\nThe only president ever to cut the deficit by more than one trillion dollars in a single year.\n\nLowering your costs also means demanding more competition.\n\nI’m a capitalist, but capitalism without competition isn’t capitalism.\n\nIt’s exploitation—and it drives up prices.\n\nWhen corporations 
don’t have to compete, their profits go up, your prices go up, and small businesses and family farmers and ranchers go under.\n\nWe see it happening with ocean carriers moving goods in and out of America.\n\nDuring the pandemic, these foreign-owned companies raised prices by as much as 1,000% and made record profits.\n\nTonight, I’m announcing a crackdown on these companies overcharging American businesses and consumers.\n\nAnd as Wall Street firms take over more nursing homes, quality in those homes has gone down and costs have gone up.\n\nThat ends on my watch.\n\nMedicare is going to set higher standards for nursing homes and make sure your loved ones get the care they deserve and expect.\n\nWe’ll also cut costs and keep the economy going strong by giving workers a fair shot, provide more training and apprenticeships, hire them based on their skills not degrees.\n\nLet’s pass the Paycheck Fairness Act and paid leave.\n\nRaise the minimum wage to $15 an hour and extend the Child Tax Credit, so no one has to raise a family in poverty.\n\nLet’s increase Pell Grants and increase our historic support of HBCUs, and invest in what Jill—our First Lady who teaches full-time—calls America’s best-kept secret: community colleges.\n\nAnd let’s pass the PRO Act when a majority of workers want to form a union—they shouldn’t be stopped.\n\nWhen we invest in our workers, when we build the economy from the bottom up and the middle out together, we can do something we haven’t done in a long time: build a better America.\n\nFor more than two years, COVID-19 has impacted every decision in our lives and the life of the nation.\n\nAnd I know you’re tired, frustrated, and exhausted.\n\nBut I also know this.\n\nBecause of the progress we’ve made, because of your resilience and the tools we have, tonight I can say\nwe are moving forward safely, back to more normal routines.\n\nWe’ve reached a new moment in the fight against COVID-19, with severe cases down to a level not seen since last 
July.\n\nJust a few days ago, the Centers for Disease Control and Prevention—the CDC—issued new mask guidelines.\n\nUnder these new guidelines, most Americans in most of the country can now be mask free.\n\nAnd based on the projections, more of the country will reach that point across the next couple of weeks.\n\nThanks to the progress we have made this past year, COVID-19 need no longer control our lives.\n\nI know some are talking about “living with COVID-19”. Tonight – I say that we will never just accept living with COVID-19.\n\nWe will continue to combat the virus as we do other diseases. And because this is a virus that mutates and spreads, we will stay on guard.\n\nHere are four common sense steps as we move forward safely.\n\nFirst, stay protected with vaccines and treatments. We know how incredibly effective vaccines are. If you’re vaccinated and boosted you have the highest degree of protection.\n\nWe will never give up on vaccinating more Americans. Now, I know parents with kids under 5 are eager to see a vaccine authorized for their children.\n\nThe scientists are working hard to get that done and we’ll be ready with plenty of vaccines when they do.\n\nWe’re also ready with anti-viral treatments. If you get COVID-19, the Pfizer pill reduces your chances of ending up in the hospital by 90%.\n\nWe’ve ordered more of these pills than anyone in the world. 
And Pfizer is working overtime to get us 1 Million pills this month and more than double that next month.\n\nAnd we’re launching the “Test to Treat” initiative so people can get tested at a pharmacy, and if they’re positive, receive antiviral pills on the spot at no cost.\n\nIf you’re immunocompromised or have some other vulnerability, we have treatments and free high-quality masks.\n\nWe’re leaving no one behind or ignoring anyone’s needs as we move forward.\n\nAnd on testing, we have made hundreds of millions of tests available for you to order for free.\n\nEven if you already ordered free tests tonight, I am announcing that you can order more from covidtests.gov starting next week.\n\nSecond – we must prepare for new variants. Over the past year, we’ve gotten much better at detecting new variants.\n\nIf necessary, we’ll be able to deploy new vaccines within 100 days instead of many more months or years.\n\nAnd, if Congress provides the funds we need, we’ll have new stockpiles of tests, masks, and pills ready if needed.\n\nI cannot promise a new variant won’t come. But I can promise you we’ll do everything within our power to be ready if it does.\n\nThird – we can end the shutdown of schools and businesses. We have the tools we need.\n\nIt’s time for Americans to get back to work and fill our great downtowns again.  People working from home can feel safe to begin to return to the office.\n\nWe’re doing that here in the federal government. The vast majority of federal workers will once again work in person.\n\nOur schools are open. Let’s keep it that way. 
Our kids need to be in school.\n\nAnd with 75% of adult Americans fully vaccinated and hospitalizations down by 77%, most Americans can remove their masks, return to work, stay in the classroom, and move forward safely.\n\nWe achieved this because we provided free vaccines, treatments, tests, and masks.\n\nOf course, continuing this costs money.\n\nI will soon send Congress a request.\n\nThe vast majority of Americans have used these tools and may want to again, so I expect Congress to pass it quickly.\n\nFourth, we will continue vaccinating the world.\n\nWe’ve sent 475 Million vaccine doses to 112 countries, more than any other nation.\n\nAnd we won’t stop.\n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life.\n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease.\n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans.\n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together.\n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera.\n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun.\n\nOfficer Mora was 27 years old.\n\nOfficer Rivera was 22.\n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers.\n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves.\n\nI’ve worked on these issues a long time.\n\nI know what works: Investing in crime prevention and community police officers who’ll walk the beat, who’ll know the neighborhood, and who can restore trust and safety.\n\nSo let’s not abandon our streets. 
Or choose between safety and equal justice.\n\nLet’s come together to protect our communities, restore trust, and hold law enforcement accountable.\n\nThat’s why the Justice Department required body cameras, banned chokeholds, and restricted no-knock warrants for its officers.\n\nThat’s why the American Rescue Plan provided $350 Billion that cities, states, and counties can use to hire more police and invest in proven strategies like community violence interruption—trusted messengers breaking the cycle of violence and trauma and giving young people hope.\n\nWe should all agree: The answer is not to Defund the police. The answer is to FUND the police with the resources and training they need to protect our communities.\n\nI ask Democrats and Republicans alike: Pass my budget and keep our neighborhoods safe.\n\nAnd I will keep doing everything in my power to crack down on gun trafficking and ghost guns you can buy online and make at home—they have no serial numbers and can’t be traced.\n\nAnd I ask Congress to pass proven measures to reduce gun violence. Pass universal background checks. Why should anyone on a terrorist list be able to purchase a weapon?\n\nBan assault weapons and high-capacity magazines.\n\nRepeal the liability shield that makes gun manufacturers the only industry in America that can’t be sued.\n\nThese laws don’t infringe on the Second Amendment. They save lives.\n\nThe most fundamental right in America is the right to vote – and to have it counted. And it’s under assault.\n\nIn state after state, new laws have been passed, not only to suppress the vote, but to subvert entire elections.\n\nWe cannot let this happen.\n\nTonight. I call on the Senate to: Pass the Freedom to Vote Act. Pass the John Lewis Voting Rights Act. 
And while you’re at it, pass the Disclose Act so Americans can know who is funding our elections.\n\nTonight, I’d like to honor someone who has dedicated his life to serve this country: Justice Stephen Breyer—an Army veteran, Constitutional scholar, and retiring Justice of the United States Supreme Court. Justice Breyer, thank you for your service.\n\nOne of the most serious constitutional responsibilities a President has is nominating someone to serve on the United States Supreme Court.\n\nAnd I did that 4 days ago, when I nominated Circuit Court of Appeals Judge Ketanji Brown Jackson. One of our nation’s top legal minds, who will continue Justice Breyer’s legacy of excellence.\n\nA former top litigator in private practice. A former federal public defender. And from a family of public school educators and police officers. A consensus builder. Since she’s been nominated, she’s received a broad range of support—from the Fraternal Order of Police to former judges appointed by Democrats and Republicans.\n\nAnd if we are to advance liberty and justice, we need to secure the Border and fix the immigration system.\n\nWe can do both. 
At our border, we’ve installed new technology like cutting-edge scanners to better detect drug smuggling.\n\nWe’ve set up joint patrols with Mexico and Guatemala to catch more human traffickers.\n\nWe’re putting in place dedicated immigration judges so families fleeing persecution and violence can have their cases heard faster.\n\nWe’re securing commitments and supporting partners in South and Central America to host more refugees and secure their own borders.\n\nWe can do all this while keeping lit the torch of liberty that has led generations of immigrants to this land—my forefathers and so many of yours.\n\nProvide a pathway to citizenship for Dreamers, those on temporary status, farm workers, and essential workers.\n\nRevise our laws so businesses have the workers they need and families don’t wait decades to reunite.\n\nIt’s not only the right thing to do—it’s the economically smart thing to do.\n\nThat’s why immigration reform is supported by everyone from labor unions to religious leaders to the U.S. Chamber of Commerce.\n\nLet’s get it done once and for all.\n\nAdvancing liberty and justice also requires protecting the rights of women.\n\nThe constitutional right affirmed in Roe v. Wade—standing precedent for half a century—is under attack as never before.\n\nIf we want to go forward—not backward—we must protect access to health care. Preserve a woman’s right to choose. And let’s continue to advance maternal health care in America.\n\nAnd for our LGBTQ+ Americans, let’s finally get the bipartisan Equality Act to my desk. The onslaught of state laws targeting transgender Americans and their families is wrong.\n\nAs I said last year, especially to our younger transgender Americans, I will always have your back as your President, so you can be yourself and reach your God-given potential.\n\nWhile it often appears that we never agree, that isn’t true. I signed 80 bipartisan bills into law last year. 
From preventing government shutdowns to protecting Asian-Americans from still-too-common hate crimes to reforming military justice.\n\nAnd soon, we’ll strengthen the Violence Against Women Act that I first wrote three decades ago. It is important for us to show the nation that we can come together and do big things.\n\nSo tonight I’m offering a Unity Agenda for the Nation. Four big things we can do together.\n\nFirst, beat the opioid epidemic.\n\nThere is so much we can do. Increase funding for prevention, treatment, harm reduction, and recovery.\n\nGet rid of outdated rules that stop doctors from prescribing treatments. And stop the flow of illicit drugs by working with state and local law enforcement to go after traffickers.\n\nIf you’re suffering from addiction, know you are not alone. I believe in recovery, and I celebrate the 23 million Americans in recovery.\n\nSecond, let’s take on mental health. Especially among our children, whose lives and education have been turned upside down.\n\nThe American Rescue Plan gave schools money to hire teachers and help students make up for lost learning.\n\nI urge every parent to make sure your school does just that. And we can all play a part—sign up to be a tutor or a mentor.\n\nChildren were also struggling before the pandemic. Bullying, violence, trauma, and the harms of social media.\n\nAs Frances Haugen, who is here with us tonight, has shown, we must hold social media platforms accountable for the national experiment they’re conducting on our children for profit.\n\nIt’s time to strengthen privacy protections, ban targeted advertising to children, demand tech companies stop collecting personal data on our children.\n\nAnd let’s get all Americans the mental health services they need. 
More people they can turn to for help, and full parity between physical and mental health care.\n\nThird, support our veterans.\n\nVeterans are the best of us.\n\nI’ve always believed that we have a sacred obligation to equip all those we send to war and care for them and their families when they come home.\n\nMy administration is providing assistance with job training and housing, and now helping lower-income veterans get VA care debt-free.\n\nOur troops in Iraq and Afghanistan faced many dangers.\n\nOne was stationed at bases and breathing in toxic smoke from “burn pits” that incinerated wastes of war—medical and hazard material, jet fuel, and more.\n\nWhen they came home, many of the world’s fittest and best trained warriors were never the same.\n\nHeadaches. Numbness. Dizziness.\n\nA cancer that would put them in a flag-draped coffin.\n\nI know.\n\nOne of those soldiers was my son Major Beau Biden.\n\nWe don’t know for sure if a burn pit was the cause of his brain cancer, or the diseases of so many of our troops.\n\nBut I’m committed to finding out everything we can.\n\nCommitted to military families like Danielle Robinson from Ohio.\n\nThe widow of Sergeant First Class Heath Robinson.\n\nHe was born a soldier. Army National Guard. Combat medic in Kosovo and Iraq.\n\nStationed near Baghdad, just yards from burn pits the size of football fields.\n\nHeath’s widow Danielle is here with us tonight. They loved going to Ohio State football games. 
He loved building Legos with their daughter.\n\nBut cancer from prolonged exposure to burn pits ravaged Heath’s lungs and body.\n\nDanielle says Heath was a fighter to the very end.\n\nHe didn’t know how to stop fighting, and neither did she.\n\nThrough her pain she found purpose to demand we do better.\n\nTonight, Danielle—we are.\n\nThe VA is pioneering new ways of linking toxic exposures to diseases, already helping more veterans get benefits.\n\nAnd tonight, I’m announcing we’re expanding eligibility to veterans suffering from nine respiratory cancers.\n\nI’m also calling on Congress: pass a law to make sure veterans devastated by toxic exposures in Iraq and Afghanistan finally get the benefits and comprehensive health care they deserve.\n\nAnd fourth, let’s end cancer as we know it.\n\nThis is personal to me and Jill, to Kamala, and to so many of you.\n\nCancer is the #2 cause of death in America–second only to heart disease.\n\nLast month, I announced our plan to supercharge the Cancer Moonshot that President Obama asked me to lead six years ago.\n\nOur goal is to cut the cancer death rate by at least 50% over the next 25 years, turn more cancers from death sentences into treatable diseases.\n\nMore support for patients and families.\n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health.\n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more.\n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more.\n\nA unity agenda for the nation.\n\nWe can do this.\n\nMy fellow Americans—tonight, we have gathered in a sacred space—the citadel of our democracy.\n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things.\n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror.\n\nAnd built the strongest, freest, and most prosperous 
nation the world has ever known.\n\nNow is the hour.\n\nOur moment of responsibility.\n\nOur test of resolve and conscience, of history itself.\n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged.\n\nWell I know this nation.\n\nWe will meet the test.\n\nTo protect freedom and liberty, to expand fairness and opportunity.\n\nWe will save democracy.\n\nAs hard as these times have been, I am more optimistic about America today than I have been my whole life.\n\nBecause I see the future that is within our grasp.\n\nBecause I know there is simply nothing beyond our capacity.\n\nWe are the only nation on Earth that has always turned every crisis we have faced into an opportunity.\n\nThe only nation that can be defined by a single word: possibilities.\n\nSo on this night, in our 245th year as a nation, I have come to report on the State of the Union.\n\nAnd my report is this: the State of the Union is strong—because you, the American people, are strong.\n\nWe are stronger today than we were a year ago.\n\nAnd we will be stronger a year from now than we are today.\n\nNow is our moment to meet and overcome the challenges of our time.\n\nAnd we will, as one people.\n\nOne America.\n\nThe United States of America.\n\nMay God bless you all. May God protect our troops.\n"
  },
  {
    "path": "ios/.gitignore",
    "content": "xuserdata\nMLCSwift/tvm_home\n*~\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Assets.xcassets/AccentColor.colorset/Contents.json",
    "content": "{\n  \"colors\" : [\n    {\n      \"idiom\" : \"universal\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Assets.xcassets/AppIcon.appiconset/Contents.json",
    "content": "{\n  \"images\" : [\n    {\n      \"filename\" : \"mlc-logo.png\",\n      \"idiom\" : \"universal\",\n      \"platform\" : \"ios\",\n      \"size\" : \"1024x1024\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Assets.xcassets/Contents.json",
    "content": "{\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Common/Constants.swift",
    "content": "//\n//  Constants.swift\n//  MLCChat\n//\n\nstruct Constants {\n    static let prebuiltModelDir = \"bundle\"\n    static let appConfigFileName = \"bundle/mlc-app-config.json\"\n    static let modelConfigFileName = \"mlc-chat-config.json\"\n    static let paramsConfigFileName = \"tensor-cache.json\"\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Info.plist",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict/>\n</plist>\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/MLCChat.entitlements",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict>\n    <key>com.apple.developer.kernel.extended-virtual-addressing</key>\n    <true/>\n    <key>com.apple.developer.kernel.increased-memory-limit</key>\n    <true/>\n</dict>\n</plist>\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/MLCChatApp.swift",
    "content": "//\n//  MLCChatApp.swift\n//  MLCChat\n//\n//  Created by Tianqi Chen on 4/26/23.\n//\n\nimport SwiftUI\n\n@main\nstruct MLCChatApp: App {\n    @StateObject private var appState = AppState()\n\n    init() {\n        UITableView.appearance().separatorStyle = .none\n        UITableView.appearance().tableFooterView = UIView()\n    }\n\n    var body: some Scene {\n        WindowGroup {\n            StartView()\n                .environmentObject(appState)\n                .task {\n                    appState.loadAppConfigAndModels()\n                }\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Models/AppConfig.swift",
    "content": "//\n//  AppConfig.swift\n//  MLCChat\n//\n\nstruct AppConfig: Codable {\n    struct ModelRecord: Codable {\n        let modelPath: String?\n        let modelURL: String?\n        let modelLib: String\n        let estimatedVRAMReq: Int\n        let modelID: String\n\n        enum CodingKeys: String, CodingKey {\n            case modelPath = \"model_path\"\n            case modelURL = \"model_url\"\n            case modelLib = \"model_lib\"\n            case estimatedVRAMReq = \"estimated_vram_bytes\"\n            case modelID = \"model_id\"\n        }\n    }\n\n    var modelList: [ModelRecord]\n\n    enum CodingKeys: String, CodingKey {\n        case modelList = \"model_list\"\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Models/ModelConfig.swift",
    "content": "//\n//  ModelConfig.swift\n//  MLCChat\n//\n\nstruct ModelConfig: Decodable {\n    let tokenizerFiles: [String]\n    var modelLib: String?\n    var modelID: String?\n    var estimatedVRAMReq: Int?\n\n    enum CodingKeys: String, CodingKey {\n        case tokenizerFiles = \"tokenizer_files\"\n        case modelLib = \"model_lib\"\n        case modelID = \"model_id\"\n        case estimatedVRAMReq = \"estimated_vram_req\"\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Models/ParamsConfig.swift",
    "content": "//\n//  ParamsConfig.swift\n//  MLCChat\n//\n\nstruct ParamsConfig: Decodable {\n    struct ParamsRecord: Decodable {\n        let dataPath: String\n    }\n\n    let records: [ParamsRecord]\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Preview Content/Preview Assets.xcassets/Contents.json",
    "content": "{\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/States/AppState.swift",
    "content": "//\n//  AppState.swift\n//  MLCChat\n//\n//  Created by Yaxing Cai on 5/13/23.\n//\n\nimport Foundation\n\nfinal class AppState: ObservableObject {\n    @Published var models = [ModelState]()\n    @Published var chatState = ChatState()\n\n    @Published var alertMessage = \"\" // TODO: Should move out\n    @Published var alertDisplayed = false // TODO: Should move out\n\n    private var appConfig: AppConfig?\n    private var modelIDs = Set<String>()\n\n    private let fileManager: FileManager = FileManager.default\n    private lazy var cacheDirectoryURL: URL = {\n        fileManager.urls(for: .cachesDirectory, in: .userDomainMask)[0]\n    }()\n\n    private let jsonDecoder = JSONDecoder()\n    private let jsonEncoder = JSONEncoder()\n\n    func loadAppConfigAndModels() {\n        appConfig = loadAppConfig()\n        // Can't do anything without a valid app config\n        guard let appConfig else {\n            return\n        }\n        loadModelsConfig(modelList: appConfig.modelList)\n    }\n\n    func requestDeleteModel(modelID: String) {\n        // model dir should have been deleted in ModelState\n        assert(!fileManager.fileExists(atPath: cacheDirectoryURL.appending(path: modelID).path()))\n        modelIDs.remove(modelID)\n        models.removeAll(where: {$0.modelConfig.modelID == modelID})\n        updateAppConfig {\n            appConfig?.modelList.removeAll(where: {$0.modelID == modelID})\n        }\n    }\n}\n\nprivate extension AppState {\n    func loadAppConfig() -> AppConfig? 
{\n        // models in cache to download\n        var appConfigFileURL = cacheDirectoryURL.appending(path: Constants.appConfigFileName)\n        if !fileManager.fileExists(atPath: appConfigFileURL.path()) {\n            appConfigFileURL = Bundle.main.bundleURL.appending(path: Constants.appConfigFileName)\n        }\n        assert(fileManager.fileExists(atPath: appConfigFileURL.path()))\n\n        do {\n            let fileHandle = try FileHandle(forReadingFrom: appConfigFileURL)\n            let data = fileHandle.readDataToEndOfFile()\n\n            let appConfig = try jsonDecoder.decode(AppConfig.self, from: data)\n            return appConfig\n        } catch {\n            showAlert(message: \"Failed to load app config: \\(error.localizedDescription)\")\n            return nil\n        }\n    }\n\n    func loadModelsConfig(modelList: [AppConfig.ModelRecord]) {\n        for model in modelList {\n            if model.modelPath != nil {\n                // local model\n                let modelDir = Bundle.main.bundleURL.appending(path: Constants.prebuiltModelDir).appending(path: model.modelPath!)\n                let modelConfigURL = modelDir.appending(path: Constants.modelConfigFileName)\n                if fileManager.fileExists(atPath: modelConfigURL.path()) {\n                    if let modelConfig = loadModelConfig(\n                        modelConfigURL: modelConfigURL,\n                        modelLib: model.modelLib,\n                        modelID: model.modelID,\n                        estimatedVRAMReq: model.estimatedVRAMReq\n                    ) {\n                        addModelConfig(\n                            modelConfig: modelConfig,\n                            modelPath: model.modelPath!,\n                            modelURL: nil,\n                            isBuiltin: true\n                        )\n                    } else {\n                        showAlert(message: \"Failed to load prebuilt model: \\(model.modelPath!)\")\n    
                }\n                } else {\n                    showAlert(message: \"Prebuilt mlc-chat-config.json file not found: \\(model.modelPath!)\")\n                }\n            } else if model.modelURL != nil {\n                // remote model\n                let modelConfigFileURL = cacheDirectoryURL\n                    .appending(path: model.modelID)\n                    .appending(path: Constants.modelConfigFileName)\n                if fileManager.fileExists(atPath: modelConfigFileURL.path()) {\n                    if let modelConfig = loadModelConfig(\n                        modelConfigURL: modelConfigFileURL,\n                        modelLib: model.modelLib,\n                        modelID: model.modelID,\n                        estimatedVRAMReq: model.estimatedVRAMReq\n                    ) {\n                        addModelConfig(\n                            modelConfig: modelConfig,\n                            modelPath: nil,\n                            modelURL: URL(string: model.modelURL!),\n                            isBuiltin: true\n                        )\n                    }\n                } else {\n                    downloadConfig(\n                        modelURL: URL(string: model.modelURL!),\n                        modelLib: model.modelLib,\n                        modelID: model.modelID,\n                        estimatedVRAMReq: model.estimatedVRAMReq,\n                        isBuiltin: true\n                    )\n                }\n            } else {\n                showAlert(message: \"Path or URL should be provided in app config: \\(model.modelID)\")\n            }\n        }\n    }\n\n    func loadModelConfig(modelConfigURL: URL, modelLib: String, modelID: String, estimatedVRAMReq: Int) -> ModelConfig? 
{\n        do {\n            assert(fileManager.fileExists(atPath: modelConfigURL.path()))\n            let fileHandle = try FileHandle(forReadingFrom: modelConfigURL)\n            let data = fileHandle.readDataToEndOfFile()\n            var modelConfig = try jsonDecoder.decode(ModelConfig.self, from: data)\n            modelConfig.modelLib = modelLib\n            modelConfig.modelID = modelID\n            modelConfig.estimatedVRAMReq = estimatedVRAMReq\n            return modelConfig\n        } catch {\n            showAlert(message: \"Failed to resolve model config: \\(error.localizedDescription)\")\n        }\n        return nil\n    }\n\n    func showAlert(message: String) {\n        DispatchQueue.main.async { [weak self] in\n            guard let self = self else { return }\n            if !self.alertDisplayed {\n                self.alertMessage = message\n                self.alertDisplayed = true\n            } else {\n                self.alertMessage.append(\"\\n\" + message)\n            }\n        }\n    }\n\n    func downloadConfig(modelURL: URL?, modelLib: String, modelID: String, estimatedVRAMReq: Int, isBuiltin: Bool) {\n        guard let modelConfigURL = modelURL?.appending(path: \"resolve\").appending(path: \"main\").appending(path: Constants.modelConfigFileName) else {\n            return\n        }\n\n        let downloadTask = URLSession.shared.downloadTask(with: modelConfigURL) {\n            [weak self] urlOrNil, responseOrNil, errorOrNil in\n            guard let self else {\n                return\n            }\n            if let error = errorOrNil {\n                self.showAlert(message: \"Failed to download model config: \\(error.localizedDescription)\")\n                return\n            }\n            guard let fileUrl = urlOrNil else {\n                self.showAlert(message: \"Failed to download model config\")\n                return\n            }\n\n            // cache temp file to avoid being deleted by system 
automatically\n            let tempName = UUID().uuidString\n            let tempFileURL = self.cacheDirectoryURL.appending(path: tempName)\n\n            do {\n                try self.fileManager.moveItem(at: fileUrl, to: tempFileURL)\n            } catch {\n                self.showAlert(message: \"Failed to cache downloaded file: \\(error.localizedDescription)\")\n                return\n            }\n\n            do {\n                guard let modelConfig = loadModelConfig(\n                    modelConfigURL: tempFileURL,\n                    modelLib: modelLib,\n                    modelID: modelID,\n                    estimatedVRAMReq: estimatedVRAMReq\n                ) else {\n                    try fileManager.removeItem(at: tempFileURL)\n                    return\n                }\n\n                if modelIDs.contains(modelConfig.modelID!) {\n                    try fileManager.removeItem(at: tempFileURL)\n                    return\n                }\n\n                let modelBaseUrl = cacheDirectoryURL.appending(path: modelConfig.modelID!)\n                try fileManager.createDirectory(at: modelBaseUrl, withIntermediateDirectories: true)\n                let modelConfigUrl = modelBaseUrl.appending(path: Constants.modelConfigFileName)\n                try fileManager.moveItem(at: tempFileURL, to: modelConfigUrl)\n                assert(fileManager.fileExists(atPath: modelConfigUrl.path()))\n                assert(!fileManager.fileExists(atPath: tempFileURL.path()))\n                addModelConfig(\n                    modelConfig: modelConfig,\n                    modelPath: nil,\n                    modelURL: modelURL,\n                    isBuiltin: isBuiltin\n                )\n            } catch {\n                showAlert(message: \"Failed to import model: \\(error.localizedDescription)\")\n            }\n        }\n        downloadTask.resume()\n    }\n\n    func addModelConfig(modelConfig: ModelConfig, modelPath: String?, 
modelURL: URL?, isBuiltin: Bool) {\n        assert(!modelIDs.contains(modelConfig.modelID!))\n        modelIDs.insert(modelConfig.modelID!)\n        let modelBaseURL: URL\n\n        // model_id dir should exist\n        if modelURL == nil {\n            // prebuilt model in bundle\n            modelBaseURL = Bundle.main.bundleURL.appending(path: Constants.prebuiltModelDir).appending(path: modelPath!)\n        } else {\n            // downloaded model in cache\n            modelBaseURL = cacheDirectoryURL.appending(path: modelConfig.modelID!)\n        }\n        assert(fileManager.fileExists(atPath: modelBaseURL.path()))\n\n        // mlc-chat-config.json should exist\n        let modelConfigURL = modelBaseURL.appending(path: Constants.modelConfigFileName)\n        assert(fileManager.fileExists(atPath: modelConfigURL.path()))\n\n        let model = ModelState(modelConfig: modelConfig, modelLocalBaseURL: modelBaseURL, startState: self, chatState: chatState)\n        model.checkModelDownloadState(modelURL: modelURL)\n\n        // addModelConfig may run off the main thread (e.g. from a download callback), so updates to models must be dispatched to main\n        DispatchQueue.main.async { [weak self] in\n            guard let self = self else { return }\n            models.append(model)\n        }\n\n        if modelURL != nil && !isBuiltin {\n            updateAppConfig {\n                appConfig?.modelList.append(\n                    AppConfig.ModelRecord(\n                        modelPath: nil,\n                        modelURL: modelURL!.absoluteString,\n                        modelLib: modelConfig.modelLib!,\n                        estimatedVRAMReq: modelConfig.estimatedVRAMReq!,\n                        modelID: modelConfig.modelID!\n                    )\n                )\n            }\n        }\n    }\n\n    func updateAppConfig(action: () -> Void) {\n        action()\n        let appConfigURL = cacheDirectoryURL.appending(path: Constants.appConfigFileName)\n        do {\n
            let data = try jsonEncoder.encode(appConfig)\n            try data.write(to: appConfigURL, options: Data.WritingOptions.atomic)\n        } catch {\n            print(error.localizedDescription)\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/States/ChatState.swift",
    "content": "//\n//  ChatState.swift\n//  MLCChat\n//\n\nimport Foundation\nimport MLCSwift\n\nenum MessageRole {\n    case user\n    case assistant\n}\n\nextension MessageRole {\n    var isUser: Bool { self == .user }\n}\n\nstruct MessageData: Hashable {\n    let id = UUID()\n    var role: MessageRole\n    var message: String\n}\n\nfinal class ChatState: ObservableObject {\n    fileprivate enum ModelChatState {\n        case generating\n        case resetting\n        case reloading\n        case terminating\n        case ready\n        case failed\n        case pendingImageUpload\n        case processingImage\n    }\n\n    @Published var displayMessages = [MessageData]()\n    @Published var infoText = \"\"\n    @Published var displayName = \"\"\n    // legacy UI option for uploading images\n    // TODO(mlc-team) support new UI for image processing\n    @Published var legacyUseImage = false\n\n    private let modelChatStateLock = NSLock()\n    private var modelChatState: ModelChatState = .ready\n\n    // the new mlc engine\n    private let engine = MLCEngine()\n    // history messages\n    private var historyMessages = [ChatCompletionMessage]()\n\n    // streaming text that gets updated\n    private var streamingText = \"\"\n\n    private var modelLib = \"\"\n    private var modelPath = \"\"\n    var modelID = \"\"\n\n    init() {\n    }\n\n    var isInterruptible: Bool {\n        return getModelChatState() == .ready\n        || getModelChatState() == .generating\n        || getModelChatState() == .failed\n        || getModelChatState() == .pendingImageUpload\n    }\n\n    var isChattable: Bool {\n        return getModelChatState() == .ready\n    }\n\n    var isUploadable: Bool {\n        return getModelChatState() == .pendingImageUpload\n    }\n\n    var isResettable: Bool {\n        return getModelChatState() == .ready\n        || getModelChatState() == .generating\n    }\n\n    func requestResetChat() {\n        assert(isResettable)\n        
interruptChat(prologue: {\n            switchToResetting()\n        }, epilogue: { [weak self] in\n            self?.mainResetChat()\n        })\n    }\n\n    // reset the chat if we switch to background\n    // during generation to avoid permission issue\n    func requestSwitchToBackground() {\n        if (getModelChatState() == .generating) {\n            self.requestResetChat()\n        }\n    }\n\n\n    func requestTerminateChat(callback: @escaping () -> Void) {\n        assert(isInterruptible)\n        interruptChat(prologue: {\n            switchToTerminating()\n        }, epilogue: { [weak self] in\n            self?.mainTerminateChat(callback: callback)\n        })\n    }\n\n    func requestReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) {\n        if (isCurrentModel(modelID: modelID)) {\n            return\n        }\n        assert(isInterruptible)\n        interruptChat(prologue: {\n            switchToReloading()\n        }, epilogue: { [weak self] in\n            self?.mainReloadChat(modelID: modelID,\n                                 modelLib: modelLib,\n                                 modelPath: modelPath,\n                                 estimatedVRAMReq: estimatedVRAMReq,\n                                 displayName: displayName)\n        })\n    }\n\n\n    func requestGenerate(prompt: String) {\n        assert(isChattable)\n        switchToGenerating()\n        appendMessage(role: .user, message: prompt)\n        appendMessage(role: .assistant, message: \"\")\n\n        Task {\n            self.historyMessages.append(\n                ChatCompletionMessage(role: .user, content: prompt)\n            )\n            var finishReasonLength = false\n            var finalUsageTextLabel = \"\"\n\n            for await res in await engine.chat.completions.create(\n                messages: self.historyMessages,\n                stream_options: StreamOptions(include_usage: true)\n            ) 
{\n                for choice in res.choices {\n                    if let content = choice.delta.content {\n                        self.streamingText += content.asText()\n                    }\n                    if let finish_reason = choice.finish_reason {\n                        if finish_reason == \"length\" {\n                            finishReasonLength = true\n                        }\n                    }\n                }\n                if let finalUsage = res.usage {\n                    finalUsageTextLabel = finalUsage.extra?.asTextLabel() ?? \"\"\n                }\n                if getModelChatState() != .generating {\n                    break\n                }\n\n                var updateText = self.streamingText\n                if finishReasonLength {\n                    updateText += \" [output truncated due to context length limit...]\"\n                }\n\n                let newText = updateText\n                DispatchQueue.main.async {\n                    self.updateMessage(role: .assistant, message: newText)\n                }\n            }\n\n            // record history messages\n            if !self.streamingText.isEmpty {\n                self.historyMessages.append(\n                    ChatCompletionMessage(role: .assistant, content: self.streamingText)\n                )\n                // stream text can be cleared\n                self.streamingText = \"\"\n            } else {\n                self.historyMessages.removeLast()\n            }\n\n            // if we exceed history\n            // we can try to reduce the history and see if it can fit\n            if (finishReasonLength) {\n                let windowSize = self.historyMessages.count\n                assert(windowSize % 2 == 0)\n                let removeEnd = ((windowSize + 3) / 4) * 2\n                self.historyMessages.removeSubrange(0..<removeEnd)\n            }\n\n            if getModelChatState() == .generating {\n                let 
runtimStats = finalUsageTextLabel\n\n                DispatchQueue.main.async {\n                    self.infoText = runtimStats\n                    self.switchToReady()\n\n                }\n            }\n        }\n    }\n\n    func isCurrentModel(modelID: String) -> Bool {\n        return self.modelID == modelID\n    }\n}\n\nprivate extension ChatState {\n    func getModelChatState() -> ModelChatState {\n        modelChatStateLock.lock()\n        defer { modelChatStateLock.unlock() }\n        return modelChatState\n    }\n\n    func setModelChatState(_ newModelChatState: ModelChatState) {\n        modelChatStateLock.lock()\n        modelChatState = newModelChatState\n        modelChatStateLock.unlock()\n    }\n\n    func appendMessage(role: MessageRole, message: String) {\n        displayMessages.append(MessageData(role: role, message: message))\n    }\n\n    func updateMessage(role: MessageRole, message: String) {\n        displayMessages[displayMessages.count - 1] = MessageData(role: role, message: message)\n    }\n\n    func clearHistory() {\n        displayMessages.removeAll()\n        infoText = \"\"\n        historyMessages.removeAll()\n        streamingText = \"\"\n    }\n\n    func switchToResetting() {\n        setModelChatState(.resetting)\n    }\n\n    func switchToGenerating() {\n        setModelChatState(.generating)\n    }\n\n    func switchToReloading() {\n        setModelChatState(.reloading)\n    }\n\n    func switchToReady() {\n        setModelChatState(.ready)\n    }\n\n    func switchToTerminating() {\n        setModelChatState(.terminating)\n    }\n\n    func switchToFailed() {\n        setModelChatState(.failed)\n    }\n\n    func switchToPendingImageUpload() {\n        setModelChatState(.pendingImageUpload)\n    }\n\n    func switchToProcessingImage() {\n        setModelChatState(.processingImage)\n    }\n\n    func interruptChat(prologue: () -> Void, epilogue: @escaping () -> Void) {\n        assert(isInterruptible)\n        if 
getModelChatState() == .ready\n            || getModelChatState() == .failed\n            || getModelChatState() == .pendingImageUpload {\n            prologue()\n            epilogue()\n        } else if getModelChatState() == .generating {\n            prologue()\n            DispatchQueue.main.async {\n                epilogue()\n            }\n        } else {\n            assert(false)\n        }\n    }\n\n    func mainResetChat() {\n        Task {\n            await engine.reset()\n            self.historyMessages = []\n            self.streamingText = \"\"\n\n            DispatchQueue.main.async {\n                self.clearHistory()\n                self.switchToReady()\n            }\n        }\n    }\n\n    func mainTerminateChat(callback: @escaping () -> Void) {\n        Task {\n            await engine.unload()\n            DispatchQueue.main.async {\n                self.clearHistory()\n                self.modelID = \"\"\n                self.modelLib = \"\"\n                self.modelPath = \"\"\n                self.displayName = \"\"\n                self.legacyUseImage = false\n                self.switchToReady()\n                callback()\n            }\n        }\n    }\n\n    func mainReloadChat(modelID: String, modelLib: String, modelPath: String, estimatedVRAMReq: Int, displayName: String) {\n        clearHistory()\n        self.modelID = modelID\n        self.modelLib = modelLib\n        self.modelPath = modelPath\n        self.displayName = displayName\n\n        Task {\n            DispatchQueue.main.async {\n                self.appendMessage(role: .assistant, message: \"[System] Initializing...\")\n            }\n\n            await engine.unload()\n            let vRAM = os_proc_available_memory()\n            if (vRAM < estimatedVRAMReq) {\n                let requiredMemory = String (\n                    format: \"%.1fMB\", Double(estimatedVRAMReq) / Double(1 << 20)\n                )\n                let errorMessage = (\n
                    \"Sorry, the system cannot provide \\(requiredMemory) VRAM as requested to the app, \" +\n                    \"so we cannot initialize this model on this device.\"\n                )\n                DispatchQueue.main.sync {\n                    self.displayMessages.append(MessageData(role: MessageRole.assistant, message: errorMessage))\n                    self.switchToFailed()\n                }\n                return\n            }\n            await engine.reload(\n                modelPath: modelPath, modelLib: modelLib\n            )\n\n            // run a simple prompt with empty content to warm up the system prompt\n            // this starts processing before the user begins typing\n            for await _ in await engine.chat.completions.create(\n                messages: [ChatCompletionMessage(role: .user, content: \"\")],\n                max_tokens: 1\n            ) {}\n\n            // TODO(mlc-team) run a system message prefill\n            DispatchQueue.main.async {\n                self.updateMessage(role: .assistant, message: \"[System] Ready to chat\")\n                self.switchToReady()\n            }\n\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/States/ModelState.swift",
    "content": "//\n//  ModelState.swift\n//  MLCChat\n//\n\nimport Foundation\n\nfinal class ModelState: ObservableObject, Identifiable {\n    enum ModelDownloadState {\n        case initializing\n        case indexing\n        case paused\n        case downloading\n        case pausing\n        case verifying\n        case finished\n        case failed\n        case clearing\n        case deleting\n    }\n\n    fileprivate struct DownloadTask: Hashable {\n        let remoteURL: URL\n        let localURL: URL\n    }\n\n    @Published var modelConfig: ModelConfig\n    @Published var modelDownloadState: ModelDownloadState = .initializing\n    @Published var progress: Int = 0\n    @Published var total: Int = 1\n\n    private var modelLocalBaseURL: URL\n    private var startState: AppState\n    private var chatState: ChatState\n\n    private let fileManager: FileManager = FileManager.default\n    private let decoder = JSONDecoder()\n    private var paramsConfig: ParamsConfig?\n    private var modelRemoteBaseURL: URL?\n    private var remainingTasks: Set<DownloadTask> = Set<DownloadTask>()\n    private var downloadingTasks: Set<DownloadTask> = Set<DownloadTask>()\n    private var maxDownloadingTasks: Int = 3\n\n    init(modelConfig: ModelConfig,\n         modelLocalBaseURL: URL,\n         startState: AppState,\n         chatState: ChatState) {\n        self.modelConfig = modelConfig\n        self.modelLocalBaseURL = modelLocalBaseURL\n        self.startState = startState\n        self.chatState = chatState\n    }\n\n    func checkModelDownloadState(modelURL: URL?) 
{\n        createModelFolderIfNeeded()\n\n        guard let modelURL else {\n            switchToVerifying()\n            return\n        }\n\n        modelRemoteBaseURL = modelURL.appending(path: \"resolve\").appending(path: \"main\")\n\n        // create local params dir\n        let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName)\n        if fileManager.fileExists(atPath: paramsConfigURL.path()) {\n            // tensor-cache.json already downloaded\n            loadParamsConfig()\n            switchToIndexing()\n        } else {\n            // download tensor-cache.json\n            downloadParamsConfig()\n        }\n    }\n\n    func startChat(chatState: ChatState) {\n        chatState.requestReloadChat(\n            modelID: modelConfig.modelID!,\n            modelLib: modelConfig.modelLib!,\n            modelPath: modelLocalBaseURL.path(),\n            estimatedVRAMReq: modelConfig.estimatedVRAMReq!,\n            displayName: modelConfig.modelID!.components(separatedBy: \"-\")[0]\n        )\n    }\n\n    func handleStart() {\n        // start downloading\n        switchToDownloading()\n    }\n\n    func handlePause() {\n        // pause downloading\n        switchToPausing()\n    }\n\n    func handleClear() {\n        assert(modelDownloadState == .downloading || modelDownloadState == .paused || modelDownloadState == .finished)\n        switchToClearing()\n    }\n\n    func handleDelete() {\n        assert(modelDownloadState == .downloading || modelDownloadState == .paused || modelDownloadState == .finished || modelDownloadState == .failed)\n        switchToDeleting()\n    }\n}\n\nprivate extension ModelState {\n    func createModelFolderIfNeeded() {\n        if !fileManager.fileExists(atPath: modelLocalBaseURL.path()) {\n            do {\n                try fileManager.createDirectory(at: modelLocalBaseURL, withIntermediateDirectories: true)\n            } catch {\n                print(error.localizedDescription)\n      
      }\n        }\n    }\n\n    func loadParamsConfig() {\n        let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName)\n        assert(fileManager.fileExists(atPath: paramsConfigURL.path()))\n        do {\n            let fileHandle = try FileHandle(forReadingFrom: paramsConfigURL)\n            let data = fileHandle.readDataToEndOfFile()\n            paramsConfig = try self.decoder.decode(ParamsConfig.self, from: data)\n        } catch {\n            print(error.localizedDescription)\n        }\n    }\n\n    func downloadParamsConfig() {\n        guard let modelRemoteBaseURL else {\n            return\n        }\n\n        let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName)\n        let downloadTask = URLSession.shared.downloadTask(with: modelRemoteBaseURL.appending(path: Constants.paramsConfigFileName)) {\n            [weak self] urlOrNil, responseOrNil, errorOrNil in\n            guard let self else { return }\n            guard let fileURL = urlOrNil else { return }\n            do {\n                try? 
self.fileManager.removeItem(at: paramsConfigURL)\n                try self.fileManager.moveItem(at: fileURL, to: paramsConfigURL)\n                DispatchQueue.main.async {\n                    self.loadParamsConfig()\n                    self.switchToIndexing()\n                }\n            } catch {\n                print(error.localizedDescription)\n            }\n        }\n        downloadTask.resume()\n    }\n\n    func switchToIndexing() {\n        guard let paramsConfig, let modelRemoteBaseURL else {\n            return\n        }\n\n        modelDownloadState = .indexing\n        progress = 0\n        total = modelConfig.tokenizerFiles.count + paramsConfig.records.count\n\n        // collect tokenizer download tasks\n        for tokenizerFile in modelConfig.tokenizerFiles {\n            let remoteURL = modelRemoteBaseURL.appending(path: tokenizerFile)\n            let localURL = modelLocalBaseURL.appending(path: tokenizerFile)\n\n            if fileManager.fileExists(atPath: localURL.path()) {\n                progress += 1\n            } else {\n                remainingTasks.insert(DownloadTask(remoteURL: remoteURL, localURL: localURL))\n            }\n        }\n\n        // collect params download tasks\n        for paramsRecord in paramsConfig.records {\n            let remoteURL = modelRemoteBaseURL.appending(path: paramsRecord.dataPath)\n            let localURL = modelLocalBaseURL.appending(path: paramsRecord.dataPath)\n\n            if fileManager.fileExists(atPath: localURL.path()) {\n                progress += 1\n            } else {\n                remainingTasks.insert(DownloadTask(remoteURL: remoteURL, localURL: localURL))\n            }\n        }\n\n        if progress < total {\n            switchToPaused()\n        } else {\n            switchToFinished()\n        }\n    }\n\n    func handleNewDownload(downloadTask: DownloadTask) {\n        // start one download task\n        assert(downloadingTasks.count < maxDownloadingTasks)\n     
   let task = URLSession.shared.downloadTask(with: downloadTask.remoteURL) {\n            [weak self] urlOrNil, responseOrNil, errorOrNil in\n            guard let self else { return }\n            guard let fileUrl = urlOrNil else {\n                DispatchQueue.main.async {\n                    self.handleCancelDownload(downloadTask: downloadTask)\n                }\n                return\n            }\n\n            do {\n                try self.fileManager.createDirectory(at: downloadTask.localURL.deletingLastPathComponent(), withIntermediateDirectories: true)\n                try? self.fileManager.removeItem(at: downloadTask.localURL)\n                try self.fileManager.moveItem(at: fileUrl, to: downloadTask.localURL)\n            } catch {\n                print(error.localizedDescription)\n            }\n            DispatchQueue.main.async {\n                self.handleFinishDownload(downloadTask: downloadTask)\n            }\n        }\n        downloadingTasks.insert(downloadTask)\n        task.resume()\n    }\n\n    func handleFinishDownload(downloadTask: DownloadTask) {\n        // update the finished download task\n        remainingTasks.remove(downloadTask)\n        downloadingTasks.remove(downloadTask)\n        progress += 1\n        assert(modelDownloadState == .downloading ||\n               modelDownloadState == .pausing ||\n               modelDownloadState == .clearing ||\n               modelDownloadState == .deleting\n        )\n        if modelDownloadState == .downloading {\n            if remainingTasks.isEmpty && downloadingTasks.isEmpty {\n                switchToFinished()\n            } else {\n                handleNextDownload()\n            }\n        } else if modelDownloadState == .pausing && downloadingTasks.isEmpty {\n            switchToPaused()\n        } else if modelDownloadState == .clearing && downloadingTasks.isEmpty {\n            clear()\n        } else if modelDownloadState == .deleting && downloadingTasks.isEmpty 
{\n            delete()\n        }\n    }\n\n    func handleCancelDownload(downloadTask: DownloadTask) {\n        // withdraw the failed download task; mirror handleFinishDownload so that a\n        // pending pause, clear, or delete still completes once no downloads remain\n        assert(modelDownloadState == .downloading ||\n               modelDownloadState == .pausing ||\n               modelDownloadState == .clearing ||\n               modelDownloadState == .deleting\n        )\n        downloadingTasks.remove(downloadTask)\n        if modelDownloadState == .downloading {\n            handleNextDownload()\n        } else if modelDownloadState == .pausing && downloadingTasks.isEmpty {\n            switchToPaused()\n        } else if modelDownloadState == .clearing && downloadingTasks.isEmpty {\n            clear()\n        } else if modelDownloadState == .deleting && downloadingTasks.isEmpty {\n            delete()\n        }\n    }\n\n    func handleNextDownload() {\n        // start next download task\n        assert(modelDownloadState == .downloading)\n        for downloadTask in remainingTasks {\n            if !downloadingTasks.contains(downloadTask) {\n                handleNewDownload(downloadTask: downloadTask)\n                break\n            }\n        }\n    }\n\n    func switchToPaused() {\n        modelDownloadState = .paused\n    }\n\n    func switchToPausing() {\n        modelDownloadState = .pausing\n    }\n\n    func switchToVerifying() {\n        modelDownloadState = .verifying\n\n        let paramsConfigURL = modelLocalBaseURL.appending(path: Constants.paramsConfigFileName)\n        guard fileManager.fileExists(atPath: paramsConfigURL.path()) else {\n            switchToFailed()\n            return\n        }\n\n        loadParamsConfig()\n        guard let paramsConfig else {\n            switchToFailed()\n            return\n        }\n        progress = 0\n        total = modelConfig.tokenizerFiles.count + paramsConfig.records.count\n\n        if !verifyTokenizers() {\n            switchToFailed()\n            return\n        }\n\n        if !verifyParams() {\n            switchToFailed()\n            return\n        }\n\n        switchToFinished()\n    }\n\n    func verifyTokenizers() -> Bool {\n        for tokenizerFile in modelConfig.tokenizerFiles {\n            let localURL = modelLocalBaseURL.appending(path: tokenizerFile)\n\n            if !fileManager.fileExists(atPath: 
localURL.path()) {\n                switchToFailed()\n                return false\n            }\n            progress += 1\n        }\n        return true\n    }\n\n    func verifyParams() -> Bool {\n        guard let paramsConfig else {\n            return false\n        }\n\n        for paramsRecord in paramsConfig.records {\n            let localUrl = modelLocalBaseURL.appending(path: paramsRecord.dataPath)\n\n            if !fileManager.fileExists(atPath: localUrl.path()) {\n                switchToFailed()\n                return false\n            }\n\n            progress += 1\n        }\n        return true\n    }\n\n    func switchToClearing() {\n        if modelDownloadState == .paused {\n            modelDownloadState = .clearing\n            clear()\n        } else if modelDownloadState == .finished {\n            if chatState.modelID == modelConfig.modelID {\n                chatState.requestTerminateChat { [weak self] in\n                    self?.clear()\n                }\n            } else {\n                clear()\n            }\n        } else {\n            modelDownloadState = .clearing\n        }\n    }\n\n    func switchToDeleting() {\n        if modelDownloadState == .paused || modelDownloadState == .failed {\n            modelDownloadState = .deleting\n            delete()\n        } else if modelDownloadState == .finished {\n            if chatState.modelID == modelConfig.modelID {\n                chatState.requestTerminateChat { [weak self] in\n                    self?.delete()\n                }\n            } else {\n                delete()\n            }\n        } else {\n            modelDownloadState = .deleting\n        }\n    }\n\n    func switchToFinished() {\n        modelDownloadState = .finished\n    }\n\n    func switchToFailed() {\n        modelDownloadState = .failed\n    }\n\n    func switchToDownloading() {\n        modelDownloadState = .downloading\n        for downloadTask in remainingTasks {\n            if 
downloadingTasks.count < maxDownloadingTasks {\n                handleNewDownload(downloadTask: downloadTask)\n            } else {\n                return\n            }\n        }\n    }\n\n    func clear() {\n        do {\n            let fileURLs = try fileManager.contentsOfDirectory(at: modelLocalBaseURL, includingPropertiesForKeys: nil)\n            for fileURL in fileURLs where fileURL.lastPathComponent != Constants.modelConfigFileName {\n                try fileManager.removeItem(at: fileURL)\n                assert(!fileManager.fileExists(atPath: fileURL.path()))\n            }\n            assert(fileManager.fileExists(atPath: modelLocalBaseURL.appending(path: Constants.modelConfigFileName).path()))\n            switchToIndexing()\n        } catch {\n            print(error.localizedDescription)\n        }\n    }\n\n    func delete() {\n        do {\n            try fileManager.removeItem(at: modelLocalBaseURL)\n            assert(!fileManager.fileExists(atPath: modelLocalBaseURL.path()))\n            startState.requestDeleteModel(modelID: modelConfig.modelID!) // TODO: can it decouple?\n        } catch {\n            print(error.localizedDescription)\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Views/ChatView.swift",
    "content": "//\n//  ChatView.swift\n//  MLCChat\n//\n\nimport SwiftUI\nimport GameController\n\nstruct ChatView: View {\n    @EnvironmentObject private var chatState: ChatState\n    @Environment(\\.scenePhase) var scenePhase\n    @State private var inputMessage: String = \"\"\n    @FocusState private var inputIsFocused: Bool\n    @Environment(\\.dismiss) private var dismiss\n    @Namespace private var messagesBottomID\n\n    // vision-related properties\n    @State private var showActionSheet: Bool = false\n    @State private var showImagePicker: Bool = false\n    @State private var imageConfirmed: Bool = false\n    @State private var imageSourceType: UIImagePickerController.SourceType = .photoLibrary\n    @State private var image: UIImage?\n\n    var body: some View {\n        VStack {\n            modelInfoView\n            messagesView\n            uploadImageView\n            messageInputView\n        }\n        .navigationBarTitle(\"MLC Chat: \\(chatState.displayName)\", displayMode: .inline)\n        .navigationBarBackButtonHidden()\n        .onChange(of: scenePhase) { oldPhase, newPhase in\n            if newPhase == .background {\n                self.chatState.requestSwitchToBackground()\n            }\n        }\n        .toolbar {\n            ToolbarItem(placement: .navigationBarLeading) {\n                Button {\n                    dismiss()\n                } label: {\n                    Image(systemName: \"chevron.backward\")\n                }\n                .buttonStyle(.borderless)\n                .disabled(!chatState.isInterruptible)\n            }\n            ToolbarItem(placement: .navigationBarTrailing) {\n                Button(\"Reset\") {\n                    image = nil\n                    imageConfirmed = false\n                    chatState.requestResetChat()\n                }\n                .padding()\n                .disabled(!chatState.isResettable)\n            }\n        }\n\n    }\n}\n\nprivate extension ChatView 
{\n    var modelInfoView: some View {\n        Text(chatState.infoText)\n            .multilineTextAlignment(.center)\n            .opacity(0.5)\n            .listRowSeparator(.hidden)\n    }\n\n    var messagesView: some View {\n        ScrollViewReader { scrollViewProxy in\n            ScrollView {\n                VStack {\n                    let messageCount = chatState.displayMessages.count\n                    let hasSystemMessage = messageCount > 0 && chatState.displayMessages[0].role == MessageRole.assistant\n                    let startIndex = hasSystemMessage ? 1 : 0\n\n                    // display the system message\n                    if hasSystemMessage {\n                        MessageView(role: chatState.displayMessages[0].role, message: chatState.displayMessages[0].message, isMarkdownSupported: false)\n                    }\n\n                    // display image\n                    if let image, imageConfirmed {\n                        ImageView(image: image)\n                    }\n\n                    // display conversations\n                    ForEach(chatState.displayMessages[startIndex...], id: \\.id) { message in\n                        MessageView(role: message.role, message: message.message)\n                    }\n                    HStack { EmptyView() }\n                        .id(messagesBottomID)\n                }\n            }\n            .onChange(of: chatState.displayMessages) { _ in\n                withAnimation {\n                    scrollViewProxy.scrollTo(messagesBottomID, anchor: .bottom)\n                }\n            }\n        }\n    }\n\n    @ViewBuilder\n    var uploadImageView: some View {\n        if chatState.legacyUseImage && !imageConfirmed {\n            if image == nil {\n                Button(\"Upload picture to chat\") {\n                    showActionSheet = true\n                }\n                .actionSheet(isPresented: $showActionSheet) {\n                    ActionSheet(title: 
Text(\"Choose from\"), buttons: [\n                        .default(Text(\"Photo Library\")) {\n                            showImagePicker = true\n                            imageSourceType = .photoLibrary\n                        },\n                        .default(Text(\"Camera\")) {\n                            showImagePicker = true\n                            imageSourceType = .camera\n                        },\n                        .cancel()\n                    ])\n                }\n                .sheet(isPresented: $showImagePicker) {\n                    ImagePicker(image: $image,\n                                showImagePicker: $showImagePicker,\n                                imageSourceType: imageSourceType)\n                }\n                .disabled(!chatState.isUploadable)\n            } else {\n                VStack {\n                    if let image {\n                        Image(uiImage: image)\n                            .resizable()\n                            .frame(width: 300, height: 300)\n\n                        HStack {\n                            Button(\"Undo\") {\n                                self.image = nil\n                            }\n                            .padding()\n\n                            Button(\"Submit\") {\n                                imageConfirmed = true\n                            }\n                            .padding()\n                        }\n                    }\n                }\n            }\n        }\n    }\n\n    var messageInputView: some View {\n        HStack {\n            TextField(\"Inputs...\", text: $inputMessage, axis: .vertical)\n                .textFieldStyle(RoundedBorderTextFieldStyle())\n                .frame(minHeight: CGFloat(30))\n                .focused($inputIsFocused)\n                .onSubmit {\n                    let isKeyboardConnected = GCKeyboard.coalesced != nil\n                    if isKeyboardConnected {\n                        
send()\n                    }\n                }\n            Button(\"Send\") {\n                send()\n            }\n            .bold()\n            .disabled(!(chatState.isChattable && inputMessage != \"\"))\n        }\n        .frame(minHeight: CGFloat(70))\n        .padding()\n    }\n\n    func send() {\n        inputIsFocused = false\n        chatState.requestGenerate(prompt: inputMessage)\n        inputMessage = \"\"\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Views/ImageProcessing.swift",
    "content": "//\n//  ImageProcessing.swift\n//  MLCChat\n//\n//  Created by Kathryn Chen on 7/8/23.\n//\n\nimport Foundation\nimport SwiftUI\nimport UIKit\n\n// adapted from Mohammad Azam: https://github.com/azamsharp/SwiftUICamera\n// delegate task to the coordinator to produce the image\nstruct ImagePicker : UIViewControllerRepresentable {\n    typealias UIViewControllerType = UIImagePickerController\n    typealias Coordinator = ImagePickerCoordinator\n\n    @Binding var image: UIImage?\n    @Binding var showImagePicker: Bool\n    var imageSourceType: UIImagePickerController.SourceType = .photoLibrary\n\n    func makeCoordinator() -> ImagePicker.Coordinator {\n        return ImagePickerCoordinator(image: $image, showImagePicker: $showImagePicker)\n    }\n\n    func makeUIViewController(context: UIViewControllerRepresentableContext<ImagePicker>) -> UIImagePickerController {\n        let picker = UIImagePickerController()\n        picker.sourceType = imageSourceType\n        picker.delegate = context.coordinator\n        return picker\n    }\n\n    func updateUIViewController(_ uiViewController: UIImagePickerController, context: UIViewControllerRepresentableContext<ImagePicker>) {}\n}\n\n// image picker coordinator handling selecting from library or taking a photo\nclass ImagePickerCoordinator: NSObject, UINavigationControllerDelegate, UIImagePickerControllerDelegate {\n    @Binding var image: UIImage?\n    @Binding var showImagePicker: Bool\n\n    init(image: Binding<UIImage?>, showImagePicker: Binding<Bool>) {\n        _image = image\n        _showImagePicker = showImagePicker\n    }\n\n    func imagePickerController(_ picker: UIImagePickerController, didFinishPickingMediaWithInfo info: [UIImagePickerController.InfoKey : Any]) {\n        if let optionalImage = info[UIImagePickerController.InfoKey.originalImage] as? 
UIImage {\n            image = optionalImage\n            showImagePicker = false\n        }\n    }\n\n    func imagePickerControllerDidCancel(_ picker: UIImagePickerController) {\n        showImagePicker = false\n    }\n}\n\n// resize the input image to given width and height\nfunc resizeImage(image: UIImage, width: Int, height: Int) -> UIImage {\n    let shape = CGSize(width: width, height: height)\n    UIGraphicsBeginImageContextWithOptions(shape, true, 0.0)\n    image.draw(in: CGRect(x: 0, y: 0, width: width, height: height))\n    let resizedImage: UIImage? = UIGraphicsGetImageFromCurrentImageContext()\n    UIGraphicsEndImageContext()\n    return resizedImage ?? image\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Views/MessageView.swift",
    "content": "//\n//  MessageView.swift\n//  MLCChat\n//\n\nimport SwiftUI\nimport MarkdownUI\n\nstruct MessageView: View {\n    let role: MessageRole;\n    let message: String\n    let isMarkdownSupported: Bool\n\n    @State private var showMarkdown: Bool\n\n    init(role: MessageRole, message: String, isMarkdownSupported: Bool = true) {\n        self.role = role\n        self.message = message\n        self.isMarkdownSupported = isMarkdownSupported\n        _showMarkdown = State(initialValue: isMarkdownSupported)\n    }\n    var body: some View {\n        let textColor = role.isUser ? Color.white : Color(UIColor.label)\n        let background = role.isUser ? Color.blue : Color(UIColor.secondarySystemBackground)\n\n        HStack {\n            if role.isUser {\n                Spacer()\n                Text(message)\n                    .padding(10)\n                    .foregroundColor(textColor)\n                    .background(background)\n                    .cornerRadius(10)\n                    .textSelection(.enabled)\n            }\n            if !role.isUser {\n                VStack(alignment: .leading) {\n                    // Toggle switch to show/hide Markdown\n                    if(isMarkdownSupported){\n                        Toggle(isOn: $showMarkdown) {\n                            Text(\"Show as Markdown\")\n                                .font(.footnote)\n                                .foregroundColor(.blue)\n                        }\n                        .padding(.bottom, 10)\n                    }\n\n                    // Conditionally display Text or Markdown\n                    if showMarkdown {\n                        Markdown {\n                            message\n                        }\n                        .markdownTheme(.gitHub)\n                        .padding(10)\n                        .foregroundColor(textColor)\n                        .background(background)\n                        .cornerRadius(10)\n    
                    .textSelection(.enabled)\n                    } else {\n                        Text(message)\n                            .padding(10)\n                            .foregroundColor(textColor)\n                            .background(background)\n                            .cornerRadius(10)\n                            .textSelection(.enabled)\n                    }\n                }\n                Spacer()\n            }\n        }\n        .padding()\n        .listRowSeparator(.hidden)\n    }\n}\n\nstruct ImageView: View {\n    let image: UIImage\n\n    var body: some View {\n        let background = Color.blue\n        HStack {\n            Spacer()\n            Image(uiImage: image)\n                .resizable()\n                .frame(width: 150, height: 150)\n                .padding(15)\n                .background(background)\n                .cornerRadius(20)\n        }\n        .padding()\n        .listRowSeparator(.hidden)\n    }\n}\n\nstruct MessageView_Previews: PreviewProvider {\n    static var previews: some View {\n        NavigationView {\n            VStack (spacing: 0){\n                ScrollView {\n                    MessageView(role: MessageRole.user, message: \"Message 1\")\n                    MessageView(role: MessageRole.assistant, message: \"Message 2\")\n                    MessageView(role: MessageRole.user, message: \"Message 3\")\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Views/ModelView.swift",
    "content": "//\n//  ModelView.swift\n//  MLCChat\n//\n//  Created by Yaxing Cai on 5/14/23.\n//\n\nimport SwiftUI\n\nstruct ModelView: View {\n    @EnvironmentObject private var modelState: ModelState\n    @EnvironmentObject private var chatState: ChatState\n    @Binding var isRemoving: Bool\n\n    @State private var isShowingDeletionConfirmation: Bool = false\n\n    var body: some View {\n        VStack(alignment: .leading) {\n            if (modelState.modelDownloadState == .finished) {\n                NavigationLink(destination:\n                                ChatView()\n                    .environmentObject(chatState)\n                    .onAppear {\n                        modelState.startChat(chatState: chatState)\n                    }\n                ) {\n                    HStack {\n                        Text(modelState.modelConfig.modelID!)\n                        Spacer()\n                        if chatState.isCurrentModel(modelID: modelState.modelConfig.modelID!) 
{\n                            Image(systemName: \"checkmark\").foregroundColor(.blue)\n                        }\n                    }\n                }\n                .buttonStyle(.borderless)\n            } else {\n                Text(modelState.modelConfig.modelID!).opacity(0.5)\n            }\n            HStack{\n                if modelState.modelDownloadState != .finished || isRemoving {\n                    ProgressView(value: Double(modelState.progress) / Double(modelState.total))\n                        .progressViewStyle(.linear)\n                }\n\n                if (modelState.modelDownloadState == .paused) {\n                    Button {\n                        modelState.handleStart()\n                    } label: {\n                        Image(systemName: \"icloud.and.arrow.down\")\n                    }\n                    .buttonStyle(.borderless)\n                } else if (modelState.modelDownloadState == .downloading) {\n                    Button {\n                        modelState.handlePause()\n                    } label: {\n                        Image(systemName: \"stop.circle\")\n                    }\n                    .buttonStyle(.borderless)\n                } else if (modelState.modelDownloadState == .failed) {\n                    Image(systemName: \"exclamationmark.triangle\")\n                        .foregroundColor(.red)\n                }\n\n                if isRemoving {\n                    Button(role: .destructive) {\n                        isShowingDeletionConfirmation = true\n                    } label: {\n                        Image(systemName: \"trash\")\n                    }\n                    .confirmationDialog(\"Delete Model\", isPresented: $isShowingDeletionConfirmation) {\n                        Button(\"Delete Model\", role: .destructive) {\n                            modelState.handleDelete()\n                        }\n                        .disabled(\n                            
modelState.modelDownloadState != .downloading &&\n                            modelState.modelDownloadState != .paused &&\n                            modelState.modelDownloadState != .finished &&\n                            modelState.modelDownloadState != .failed)\n                        Button(\"Clear Data\") {\n                            modelState.handleClear()\n                        }\n                        .disabled(\n                            modelState.modelDownloadState != .downloading &&\n                            modelState.modelDownloadState != .paused &&\n                            modelState.modelDownloadState != .finished)\n                        Button(\"Cancel\", role: .cancel) {\n                            isShowingDeletionConfirmation = false\n                        }\n                    } message: {\n                        Text(\"Delete Model removes all files, including the model config, and removes the entry from the list.\\nClear Data keeps only the model config and the list entry, so the model can be re-downloaded later.\")\n                    }\n                    .buttonStyle(.borderless)\n                }\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat/Views/StartView.swift",
    "content": "//\n//  DownloadView.swift\n//  MLCChat\n//\n//  Created by Yaxing Cai on 5/11/23.\n//\n\nimport SwiftUI\n\nstruct StartView: View {\n    @EnvironmentObject private var appState: AppState\n    @State private var isAdding: Bool = false\n    @State private var isRemoving: Bool = false\n    @State private var inputModelUrl: String = \"\"\n\n    var body: some View {\n        NavigationStack {\n            List{\n                Section(header: Text(\"Models\")) {\n                    ForEach(appState.models) { modelState in\n                        ModelView(isRemoving: $isRemoving)\n                            .environmentObject(modelState)\n                            .environmentObject(appState.chatState)\n                    }\n                    if !isRemoving {\n                        Button(\"Edit model\") {\n                            isRemoving = true\n                        }\n                        .buttonStyle(.borderless)\n                    } else {\n                        Button(\"Cancel edit model\") {\n                            isRemoving = false\n                        }\n                        .buttonStyle(.borderless)\n                    }\n                }\n            }\n            .navigationTitle(\"MLC Chat\")\n            .alert(\"Error\", isPresented: $appState.alertDisplayed) {\n                Button(\"OK\") { }\n            } message: {\n                Text(appState.alertMessage)\n            }\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat.xcodeproj/project.pbxproj",
    "content": "// !$*UTF8*$!\n{\n    archiveVersion = 1;\n    classes = {\n    };\n    objectVersion = 60;\n    objects = {\n\n/* Begin PBXBuildFile section */\n        1453A4CF2A1354B9001B909F /* StartView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CA2A1354B9001B909F /* StartView.swift */; };\n        1453A4D02A1354B9001B909F /* ModelView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CB2A1354B9001B909F /* ModelView.swift */; };\n        1453A4D12A1354B9001B909F /* AppState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CC2A1354B9001B909F /* AppState.swift */; };\n        1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CD2A1354B9001B909F /* ModelConfig.swift */; };\n        1453A4D32A1354B9001B909F /* ModelState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1453A4CE2A1354B9001B909F /* ModelState.swift */; };\n        A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */ = {isa = PBXBuildFile; fileRef = A773CC642A5DC98200467BFE /* ImageProcessing.swift */; };\n        AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */; };\n        AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27EFB2A85C3B000254E67 /* AppConfig.swift */; };\n        AEC27F022A86337E00254E67 /* Constants.swift in Sources */ = {isa = PBXBuildFile; fileRef = AEC27F012A86337E00254E67 /* Constants.swift */; };\n        B08647022C6D0293001A8B5E /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = B08647012C6D0293001A8B5E /* MarkdownUI */; };\n        C04105DD2BEBBEA6005A434D /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C04105DC2BEBBEA6005A434D /* MLCSwift */; };\n        C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */; };\n        
C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B629F99A80004DDAA4 /* Assets.xcassets */; };\n        C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */; };\n        C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C229F99B07004DDAA4 /* ChatView.swift */; };\n        C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C729F99B34004DDAA4 /* MessageView.swift */; };\n        C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0D643C029F99B07004DDAA4 /* ChatState.swift */; };\n        F3C280002BEB16ED00F1E016 /* bundle in CopyFiles */ = {isa = PBXBuildFile; fileRef = F3C27FFF2BEB16ED00F1E016 /* bundle */; };\n/* End PBXBuildFile section */\n\n/* Begin PBXCopyFilesBuildPhase section */\n        C06A74F129F9A78000BC4BE6 /* CopyFiles */ = {\n            isa = PBXCopyFilesBuildPhase;\n            buildActionMask = 2147483647;\n            dstPath = \"\";\n            dstSubfolderSpec = 7;\n            files = (\n                F3C280002BEB16ED00F1E016 /* bundle in CopyFiles */,\n            );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n        C0D643CF29F99C5D004DDAA4 /* Embed Libraries */ = {\n            isa = PBXCopyFilesBuildPhase;\n            buildActionMask = 2147483647;\n            dstPath = \"\";\n            dstSubfolderSpec = 10;\n            files = (\n            );\n            name = \"Embed Libraries\";\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXCopyFilesBuildPhase section */\n\n/* Begin PBXFileReference section */\n        1453A4CA2A1354B9001B909F /* StartView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = StartView.swift; sourceTree = \"<group>\"; };\n        
1453A4CB2A1354B9001B909F /* ModelView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelView.swift; sourceTree = \"<group>\"; };\n        1453A4CC2A1354B9001B909F /* AppState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = AppState.swift; sourceTree = \"<group>\"; };\n        1453A4CD2A1354B9001B909F /* ModelConfig.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelConfig.swift; sourceTree = \"<group>\"; };\n        1453A4CE2A1354B9001B909F /* ModelState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ModelState.swift; sourceTree = \"<group>\"; };\n        A773CC642A5DC98200467BFE /* ImageProcessing.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ImageProcessing.swift; sourceTree = \"<group>\"; };\n        AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ParamsConfig.swift; sourceTree = \"<group>\"; };\n        AEC27EFB2A85C3B000254E67 /* AppConfig.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppConfig.swift; sourceTree = \"<group>\"; };\n        AEC27F012A86337E00254E67 /* Constants.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Constants.swift; sourceTree = \"<group>\"; };\n        C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCChat.entitlements; sourceTree = \"<group>\"; };\n        C0D643AF29F99A7F004DDAA4 /* MLCChat.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCChat.app; sourceTree = BUILT_PRODUCTS_DIR; };\n        C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = 
MLCChatApp.swift; sourceTree = \"<group>\"; };\n        C0D643B629F99A80004DDAA4 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = \"<group>\"; };\n        C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = \"Preview Assets.xcassets\"; sourceTree = \"<group>\"; };\n        C0D643C029F99B07004DDAA4 /* ChatState.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatState.swift; sourceTree = \"<group>\"; };\n        C0D643C229F99B07004DDAA4 /* ChatView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ChatView.swift; sourceTree = \"<group>\"; };\n        C0D643C729F99B34004DDAA4 /* MessageView.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = MessageView.swift; sourceTree = \"<group>\"; };\n        C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */ = {isa = PBXFileReference; lastKnownFileType = wrapper; path = MLCSwift; sourceTree = \"<group>\"; };\n        F3C27FFF2BEB16ED00F1E016 /* bundle */ = {isa = PBXFileReference; lastKnownFileType = folder; name = bundle; path = dist/bundle; sourceTree = \"<group>\"; };\n/* End PBXFileReference section */\n\n/* Begin PBXFrameworksBuildPhase section */\n        C0D643AC29F99A7F004DDAA4 /* Frameworks */ = {\n            isa = PBXFrameworksBuildPhase;\n            buildActionMask = 2147483647;\n            files = (\n                C04105DD2BEBBEA6005A434D /* MLCSwift in Frameworks */,\n                B08647022C6D0293001A8B5E /* MarkdownUI in Frameworks */,\n            );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXFrameworksBuildPhase section */\n\n/* Begin PBXGroup section */\n        AEC27EF82A85C29000254E67 /* Models */ = {\n            isa = PBXGroup;\n            children = (\n                
1453A4CD2A1354B9001B909F /* ModelConfig.swift */,\n                AEC27EF92A85C2AC00254E67 /* ParamsConfig.swift */,\n                AEC27EFB2A85C3B000254E67 /* AppConfig.swift */,\n            );\n            path = Models;\n            sourceTree = \"<group>\";\n        };\n        AEC27EFF2A85EE2800254E67 /* States */ = {\n            isa = PBXGroup;\n            children = (\n                1453A4CE2A1354B9001B909F /* ModelState.swift */,\n                1453A4CC2A1354B9001B909F /* AppState.swift */,\n                C0D643C029F99B07004DDAA4 /* ChatState.swift */,\n            );\n            path = States;\n            sourceTree = \"<group>\";\n        };\n        AEC27F002A86306800254E67 /* Views */ = {\n            isa = PBXGroup;\n            children = (\n                A773CC642A5DC98200467BFE /* ImageProcessing.swift */,\n                1453A4CB2A1354B9001B909F /* ModelView.swift */,\n                1453A4CA2A1354B9001B909F /* StartView.swift */,\n                C0D643C729F99B34004DDAA4 /* MessageView.swift */,\n                C0D643C229F99B07004DDAA4 /* ChatView.swift */,\n            );\n            path = Views;\n            sourceTree = \"<group>\";\n        };\n        AEC27F032A86338800254E67 /* Common */ = {\n            isa = PBXGroup;\n            children = (\n                AEC27F012A86337E00254E67 /* Constants.swift */,\n            );\n            path = Common;\n            sourceTree = \"<group>\";\n        };\n        C0D643A629F99A7F004DDAA4 = {\n            isa = PBXGroup;\n            children = (\n                F3C27FFF2BEB16ED00F1E016 /* bundle */,\n                C0DDBDF02A39068900E9D060 /* Packages */,\n                C0D643B129F99A7F004DDAA4 /* MLCChat */,\n                C0D643B029F99A7F004DDAA4 /* Products */,\n                C0D643C929F99BDA004DDAA4 /* Frameworks */,\n            );\n            sourceTree = \"<group>\";\n        };\n        C0D643B029F99A7F004DDAA4 /* Products */ = {\n            isa = 
PBXGroup;\n            children = (\n                C0D643AF29F99A7F004DDAA4 /* MLCChat.app */,\n            );\n            name = Products;\n            sourceTree = \"<group>\";\n        };\n        C0D643B129F99A7F004DDAA4 /* MLCChat */ = {\n            isa = PBXGroup;\n            children = (\n                AEC27F032A86338800254E67 /* Common */,\n                AEC27EF82A85C29000254E67 /* Models */,\n                AEC27EFF2A85EE2800254E67 /* States */,\n                AEC27F002A86306800254E67 /* Views */,\n                C06A74E629F9A1DF00BC4BE6 /* MLCChat.entitlements */,\n                C0D643B229F99A7F004DDAA4 /* MLCChatApp.swift */,\n                C0D643B629F99A80004DDAA4 /* Assets.xcassets */,\n                C0D643B829F99A80004DDAA4 /* Preview Content */,\n            );\n            path = MLCChat;\n            sourceTree = \"<group>\";\n        };\n        C0D643B829F99A80004DDAA4 /* Preview Content */ = {\n            isa = PBXGroup;\n            children = (\n                C0D643B929F99A80004DDAA4 /* Preview Assets.xcassets */,\n            );\n            path = \"Preview Content\";\n            sourceTree = \"<group>\";\n        };\n        C0D643C929F99BDA004DDAA4 /* Frameworks */ = {\n            isa = PBXGroup;\n            children = (\n            );\n            name = Frameworks;\n            sourceTree = \"<group>\";\n        };\n        C0DDBDF02A39068900E9D060 /* Packages */ = {\n            isa = PBXGroup;\n            children = (\n                C0DDBE0B2A3BA6F800E9D060 /* MLCSwift */,\n            );\n            name = Packages;\n            sourceTree = \"<group>\";\n        };\n/* End PBXGroup section */\n\n/* Begin PBXNativeTarget section */\n        C0D643AE29F99A7F004DDAA4 /* MLCChat */ = {\n            isa = PBXNativeTarget;\n            buildConfigurationList = C0D643BD29F99A80004DDAA4 /* Build configuration list for PBXNativeTarget \"MLCChat\" */;\n            buildPhases = (\n                
C0D643AB29F99A7F004DDAA4 /* Sources */,\n                C0D643AC29F99A7F004DDAA4 /* Frameworks */,\n                C0D643AD29F99A7F004DDAA4 /* Resources */,\n                C0D643CF29F99C5D004DDAA4 /* Embed Libraries */,\n                C06A74F129F9A78000BC4BE6 /* CopyFiles */,\n            );\n            buildRules = (\n            );\n            dependencies = (\n            );\n            name = MLCChat;\n            packageProductDependencies = (\n                C04105DC2BEBBEA6005A434D /* MLCSwift */,\n                B08647012C6D0293001A8B5E /* MarkdownUI */,\n            );\n            productName = MLCChat;\n            productReference = C0D643AF29F99A7F004DDAA4 /* MLCChat.app */;\n            productType = \"com.apple.product-type.application\";\n        };\n/* End PBXNativeTarget section */\n\n/* Begin PBXProject section */\n        C0D643A729F99A7F004DDAA4 /* Project object */ = {\n            isa = PBXProject;\n            attributes = {\n                BuildIndependentTargetsInParallel = 1;\n                LastSwiftUpdateCheck = 1430;\n                LastUpgradeCheck = 1430;\n                TargetAttributes = {\n                    C0D643AE29F99A7F004DDAA4 = {\n                        CreatedOnToolsVersion = 14.3;\n                        LastSwiftMigration = 1430;\n                    };\n                };\n            };\n            buildConfigurationList = C0D643AA29F99A7F004DDAA4 /* Build configuration list for PBXProject \"MLCChat\" */;\n            compatibilityVersion = \"Xcode 14.0\";\n            developmentRegion = en;\n            hasScannedForEncodings = 0;\n            knownRegions = (\n                en,\n                Base,\n            );\n            mainGroup = C0D643A629F99A7F004DDAA4;\n            packageReferences = (\n                C04105DB2BEBBEA6005A434D /* XCLocalSwiftPackageReference \"../MLCSwift\" */,\n                B08647002C6D0293001A8B5E /* XCRemoteSwiftPackageReference \"swift-markdown-ui\" */,\n   
         );\n            productRefGroup = C0D643B029F99A7F004DDAA4 /* Products */;\n            projectDirPath = \"\";\n            projectRoot = \"\";\n            targets = (\n                C0D643AE29F99A7F004DDAA4 /* MLCChat */,\n            );\n        };\n/* End PBXProject section */\n\n/* Begin PBXResourcesBuildPhase section */\n        C0D643AD29F99A7F004DDAA4 /* Resources */ = {\n            isa = PBXResourcesBuildPhase;\n            buildActionMask = 2147483647;\n            files = (\n                C0D643BA29F99A80004DDAA4 /* Preview Assets.xcassets in Resources */,\n                C0D643B729F99A80004DDAA4 /* Assets.xcassets in Resources */,\n            );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXResourcesBuildPhase section */\n\n/* Begin PBXSourcesBuildPhase section */\n        C0D643AB29F99A7F004DDAA4 /* Sources */ = {\n            isa = PBXSourcesBuildPhase;\n            buildActionMask = 2147483647;\n            files = (\n                A773CC652A5DC98200467BFE /* ImageProcessing.swift in Sources */,\n                1453A4D12A1354B9001B909F /* AppState.swift in Sources */,\n                C0D643B329F99A7F004DDAA4 /* MLCChatApp.swift in Sources */,\n                C0DDBDF62A39103F00E9D060 /* ChatState.swift in Sources */,\n                C0D643C429F99B07004DDAA4 /* ChatView.swift in Sources */,\n                1453A4D32A1354B9001B909F /* ModelState.swift in Sources */,\n                C0D643C829F99B34004DDAA4 /* MessageView.swift in Sources */,\n                1453A4D22A1354B9001B909F /* ModelConfig.swift in Sources */,\n                AEC27EFA2A85C2AC00254E67 /* ParamsConfig.swift in Sources */,\n                AEC27EFC2A85C3B000254E67 /* AppConfig.swift in Sources */,\n                AEC27F022A86337E00254E67 /* Constants.swift in Sources */,\n                1453A4D02A1354B9001B909F /* ModelView.swift in Sources */,\n                1453A4CF2A1354B9001B909F /* StartView.swift in Sources */,\n      
      );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXSourcesBuildPhase section */\n\n/* Begin XCBuildConfiguration section */\n        C0D643BB29F99A80004DDAA4 /* Debug */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = {\n                ALWAYS_SEARCH_USER_PATHS = NO;\n                CLANG_ANALYZER_NONNULL = YES;\n                CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;\n                CLANG_CXX_LANGUAGE_STANDARD = \"gnu++20\";\n                CLANG_ENABLE_MODULES = YES;\n                CLANG_ENABLE_OBJC_ARC = YES;\n                CLANG_ENABLE_OBJC_WEAK = YES;\n                CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;\n                CLANG_WARN_BOOL_CONVERSION = YES;\n                CLANG_WARN_COMMA = YES;\n                CLANG_WARN_CONSTANT_CONVERSION = YES;\n                CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;\n                CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;\n                CLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n                CLANG_WARN_EMPTY_BODY = YES;\n                CLANG_WARN_ENUM_CONVERSION = YES;\n                CLANG_WARN_INFINITE_RECURSION = YES;\n                CLANG_WARN_INT_CONVERSION = YES;\n                CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;\n                CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;\n                CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;\n                CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;\n                CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;\n                CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;\n                CLANG_WARN_STRICT_PROTOTYPES = YES;\n                CLANG_WARN_SUSPICIOUS_MOVE = YES;\n                CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;\n                CLANG_WARN_UNREACHABLE_CODE = YES;\n                CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;\n                COPY_PHASE_STRIP = NO;\n                DEBUG_INFORMATION_FORMAT = dwarf;\n            
    ENABLE_STRICT_OBJC_MSGSEND = YES;\n                ENABLE_TESTABILITY = YES;\n                GCC_C_LANGUAGE_STANDARD = gnu11;\n                GCC_DYNAMIC_NO_PIC = NO;\n                GCC_NO_COMMON_BLOCKS = YES;\n                GCC_OPTIMIZATION_LEVEL = 0;\n                GCC_PREPROCESSOR_DEFINITIONS = (\n                    \"DEBUG=1\",\n                    \"$(inherited)\",\n                );\n                GCC_WARN_64_TO_32_BIT_CONVERSION = YES;\n                GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;\n                GCC_WARN_UNDECLARED_SELECTOR = YES;\n                GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;\n                GCC_WARN_UNUSED_FUNCTION = YES;\n                GCC_WARN_UNUSED_VARIABLE = YES;\n                IPHONEOS_DEPLOYMENT_TARGET = 16.0;\n                MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;\n                MTL_FAST_MATH = YES;\n                ONLY_ACTIVE_ARCH = YES;\n                SDKROOT = iphoneos;\n                SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;\n                SWIFT_OPTIMIZATION_LEVEL = \"-Onone\";\n            };\n            name = Debug;\n        };\n        C0D643BC29F99A80004DDAA4 /* Release */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = {\n                ALWAYS_SEARCH_USER_PATHS = NO;\n                CLANG_ANALYZER_NONNULL = YES;\n                CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;\n                CLANG_CXX_LANGUAGE_STANDARD = \"gnu++20\";\n                CLANG_ENABLE_MODULES = YES;\n                CLANG_ENABLE_OBJC_ARC = YES;\n                CLANG_ENABLE_OBJC_WEAK = YES;\n                CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;\n                CLANG_WARN_BOOL_CONVERSION = YES;\n                CLANG_WARN_COMMA = YES;\n                CLANG_WARN_CONSTANT_CONVERSION = YES;\n                CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;\n                CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;\n                
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n                CLANG_WARN_EMPTY_BODY = YES;\n                CLANG_WARN_ENUM_CONVERSION = YES;\n                CLANG_WARN_INFINITE_RECURSION = YES;\n                CLANG_WARN_INT_CONVERSION = YES;\n                CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;\n                CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;\n                CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;\n                CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;\n                CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;\n                CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;\n                CLANG_WARN_STRICT_PROTOTYPES = YES;\n                CLANG_WARN_SUSPICIOUS_MOVE = YES;\n                CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;\n                CLANG_WARN_UNREACHABLE_CODE = YES;\n                CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;\n                COPY_PHASE_STRIP = NO;\n                DEBUG_INFORMATION_FORMAT = \"dwarf-with-dsym\";\n                ENABLE_NS_ASSERTIONS = NO;\n                ENABLE_STRICT_OBJC_MSGSEND = YES;\n                GCC_C_LANGUAGE_STANDARD = gnu11;\n                GCC_NO_COMMON_BLOCKS = YES;\n                GCC_WARN_64_TO_32_BIT_CONVERSION = YES;\n                GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;\n                GCC_WARN_UNDECLARED_SELECTOR = YES;\n                GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;\n                GCC_WARN_UNUSED_FUNCTION = YES;\n                GCC_WARN_UNUSED_VARIABLE = YES;\n                IPHONEOS_DEPLOYMENT_TARGET = 16.0;\n                MTL_ENABLE_DEBUG_INFO = NO;\n                MTL_FAST_MATH = YES;\n                SDKROOT = iphoneos;\n                SWIFT_COMPILATION_MODE = wholemodule;\n                SWIFT_OPTIMIZATION_LEVEL = \"-O\";\n                VALIDATE_PRODUCT = YES;\n            };\n            name = Release;\n        };\n        C0D643BE29F99A80004DDAA4 /* Debug */ = {\n            isa = XCBuildConfiguration;\n            
buildSettings = {\n                ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;\n                ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;\n                CLANG_ENABLE_MODULES = YES;\n                CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements;\n                CODE_SIGN_IDENTITY = \"Apple Development\";\n                CODE_SIGN_STYLE = Automatic;\n                CURRENT_PROJECT_VERSION = 1;\n                DEVELOPMENT_ASSET_PATHS = \"\\\"MLCChat/Preview Content\\\"\";\n                DEVELOPMENT_TEAM = 3FR42MXLK9;\n                ENABLE_PREVIEWS = YES;\n                GENERATE_INFOPLIST_FILE = YES;\n                \"HEADER_SEARCH_PATHS[arch=*]\" = \"\";\n                INFOPLIST_FILE = MLCChat/Info.plist;\n                INFOPLIST_KEY_LSApplicationCategoryType = \"public.app-category.productivity\";\n                INFOPLIST_KEY_NSCameraUsageDescription = \"This app requires usage of camera to function properly.\";\n                INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;\n                INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;\n                INFOPLIST_KEY_UILaunchScreen_Generation = YES;\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = \"UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = \"UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                IPHONEOS_DEPLOYMENT_TARGET = 17.0;\n                LD_RUNPATH_SEARCH_PATHS = (\n                    \"$(inherited)\",\n                    \"@executable_path/Frameworks\",\n                );\n                LIBRARY_SEARCH_PATHS = (\n                    \"$(inherited)\",\n                    \"$(PROJECT_DIR)/dist/lib\",\n                );\n                MARKETING_VERSION = 1.6;\n 
               OTHER_LDFLAGS = (\n                    \"-Wl,-all_load\",\n                    \"-lmodel_iphone\",\n                    \"-lmlc_llm\",\n                    \"-ltvm_ffi_static\",\n                    \"-ltvm_runtime\",\n                    \"-ltokenizers_cpp\",\n                    \"-lsentencepiece\",\n                    \"-ltokenizers_c\",\n                );\n                PRODUCT_BUNDLE_IDENTIFIER = mlc.Chat;\n                PRODUCT_NAME = \"$(TARGET_NAME)\";\n                PROVISIONING_PROFILE_SPECIFIER = \"\";\n                SWIFT_EMIT_LOC_STRINGS = YES;\n                SWIFT_OBJC_BRIDGING_HEADER = \"\";\n                SWIFT_OPTIMIZATION_LEVEL = \"-Onone\";\n                SWIFT_VERSION = 5.0;\n                TARGETED_DEVICE_FAMILY = \"1,2\";\n            };\n            name = Debug;\n        };\n        C0D643BF29F99A80004DDAA4 /* Release */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = {\n                ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;\n                ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;\n                CLANG_ENABLE_MODULES = YES;\n                CODE_SIGN_ENTITLEMENTS = MLCChat/MLCChat.entitlements;\n                CODE_SIGN_IDENTITY = \"Apple Development\";\n                CODE_SIGN_STYLE = Automatic;\n                CURRENT_PROJECT_VERSION = 1;\n                DEVELOPMENT_ASSET_PATHS = \"\\\"MLCChat/Preview Content\\\"\";\n                DEVELOPMENT_TEAM = 3FR42MXLK9;\n                ENABLE_PREVIEWS = YES;\n                GENERATE_INFOPLIST_FILE = YES;\n                \"HEADER_SEARCH_PATHS[arch=*]\" = \"\";\n                INFOPLIST_FILE = MLCChat/Info.plist;\n                INFOPLIST_KEY_LSApplicationCategoryType = \"public.app-category.productivity\";\n                INFOPLIST_KEY_NSCameraUsageDescription = \"This app requires usage of camera to function properly.\";\n                INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;\n   
             INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;\n                INFOPLIST_KEY_UILaunchScreen_Generation = YES;\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = \"UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = \"UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                IPHONEOS_DEPLOYMENT_TARGET = 17.0;\n                LD_RUNPATH_SEARCH_PATHS = (\n                    \"$(inherited)\",\n                    \"@executable_path/Frameworks\",\n                );\n                LIBRARY_SEARCH_PATHS = (\n                    \"$(inherited)\",\n                    \"$(PROJECT_DIR)/dist/lib\",\n                );\n                MARKETING_VERSION = 1.6;\n                OTHER_LDFLAGS = (\n                    \"-Wl,-all_load\",\n                    \"-lmodel_iphone\",\n                    \"-lmlc_llm\",\n                    \"-ltvm_ffi_static\",\n                    \"-ltvm_runtime\",\n                    \"-ltokenizers_cpp\",\n                    \"-lsentencepiece\",\n                    \"-ltokenizers_c\",\n                );\n                PRODUCT_BUNDLE_IDENTIFIER = mlc.Chat;\n                PRODUCT_NAME = \"$(TARGET_NAME)\";\n                PROVISIONING_PROFILE_SPECIFIER = \"\";\n                SWIFT_EMIT_LOC_STRINGS = YES;\n                SWIFT_OBJC_BRIDGING_HEADER = \"\";\n                SWIFT_VERSION = 5.0;\n                TARGETED_DEVICE_FAMILY = \"1,2\";\n            };\n            name = Release;\n        };\n/* End XCBuildConfiguration section */\n\n/* Begin XCConfigurationList section */\n        C0D643AA29F99A7F004DDAA4 /* Build configuration list for PBXProject \"MLCChat\" */ = {\n            isa = XCConfigurationList;\n            buildConfigurations = (\n         
       C0D643BB29F99A80004DDAA4 /* Debug */,\n                C0D643BC29F99A80004DDAA4 /* Release */,\n            );\n            defaultConfigurationIsVisible = 0;\n            defaultConfigurationName = Release;\n        };\n        C0D643BD29F99A80004DDAA4 /* Build configuration list for PBXNativeTarget \"MLCChat\" */ = {\n            isa = XCConfigurationList;\n            buildConfigurations = (\n                C0D643BE29F99A80004DDAA4 /* Debug */,\n                C0D643BF29F99A80004DDAA4 /* Release */,\n            );\n            defaultConfigurationIsVisible = 0;\n            defaultConfigurationName = Release;\n        };\n/* End XCConfigurationList section */\n\n/* Begin XCLocalSwiftPackageReference section */\n        C04105DB2BEBBEA6005A434D /* XCLocalSwiftPackageReference \"../MLCSwift\" */ = {\n            isa = XCLocalSwiftPackageReference;\n            relativePath = ../MLCSwift;\n        };\n/* End XCLocalSwiftPackageReference section */\n\n/* Begin XCRemoteSwiftPackageReference section */\n        B08647002C6D0293001A8B5E /* XCRemoteSwiftPackageReference \"swift-markdown-ui\" */ = {\n            isa = XCRemoteSwiftPackageReference;\n            repositoryURL = \"https://github.com/gonzalezreal/swift-markdown-ui\";\n            requirement = {\n                kind = upToNextMajorVersion;\n                minimumVersion = 2.4.0;\n            };\n        };\n/* End XCRemoteSwiftPackageReference section */\n\n/* Begin XCSwiftPackageProductDependency section */\n        B08647012C6D0293001A8B5E /* MarkdownUI */ = {\n            isa = XCSwiftPackageProductDependency;\n            package = B08647002C6D0293001A8B5E /* XCRemoteSwiftPackageReference \"swift-markdown-ui\" */;\n            productName = MarkdownUI;\n        };\n        C04105DC2BEBBEA6005A434D /* MLCSwift */ = {\n            isa = XCSwiftPackageProductDependency;\n            productName = MLCSwift;\n        };\n/* End XCSwiftPackageProductDependency section */\n    };\n    rootObject = 
C0D643A729F99A7F004DDAA4 /* Project object */;\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/contents.xcworkspacedata",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Workspace\n   version = \"1.0\">\n   <FileRef\n      location = \"self:\">\n   </FileRef>\n</Workspace>\n"
  },
  {
    "path": "ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict>\n    <key>IDEDidComputeMac32BitWarning</key>\n    <true/>\n</dict>\n</plist>\n"
  },
  {
    "path": "ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/WorkspaceSettings.xcsettings",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict/>\n</plist>\n"
  },
  {
    "path": "ios/MLCChat/MLCChat.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved",
    "content": "{\n  \"originHash\" : \"1fe9890dccb6a9581bfc88793e2dc9fc7b4589153f379b73d5ac1114daef8442\",\n  \"pins\" : [\n    {\n      \"identity\" : \"networkimage\",\n      \"kind\" : \"remoteSourceControl\",\n      \"location\" : \"https://github.com/gonzalezreal/NetworkImage\",\n      \"state\" : {\n        \"revision\" : \"2849f5323265386e200484b0d0f896e73c3411b9\",\n        \"version\" : \"6.0.1\"\n      }\n    },\n    {\n      \"identity\" : \"swift-cmark\",\n      \"kind\" : \"remoteSourceControl\",\n      \"location\" : \"https://github.com/swiftlang/swift-cmark\",\n      \"state\" : {\n        \"revision\" : \"5d9bdaa4228b381639fff09403e39a04926e2dbe\",\n        \"version\" : \"0.7.1\"\n      }\n    },\n    {\n      \"identity\" : \"swift-markdown-ui\",\n      \"kind\" : \"remoteSourceControl\",\n      \"location\" : \"https://github.com/gonzalezreal/swift-markdown-ui\",\n      \"state\" : {\n        \"revision\" : \"5f613358148239d0292c0cef674a3c2314737f9e\",\n        \"version\" : \"2.4.1\"\n      }\n    }\n  ],\n  \"version\" : 3\n}\n"
  },
  {
    "path": "ios/MLCChat/MLCChat.xcodeproj/xcshareddata/xcschemes/MLCChat.xcscheme",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Scheme\n   LastUpgradeVersion = \"1430\"\n   version = \"1.7\">\n   <BuildAction\n      parallelizeBuildables = \"YES\"\n      buildImplicitDependencies = \"YES\">\n      <BuildActionEntries>\n         <BuildActionEntry\n            buildForTesting = \"YES\"\n            buildForRunning = \"YES\"\n            buildForProfiling = \"YES\"\n            buildForArchiving = \"YES\"\n            buildForAnalyzing = \"YES\">\n            <BuildableReference\n               BuildableIdentifier = \"primary\"\n               BlueprintIdentifier = \"C0D643AE29F99A7F004DDAA4\"\n               BuildableName = \"MLCChat.app\"\n               BlueprintName = \"MLCChat\"\n               ReferencedContainer = \"container:MLCChat.xcodeproj\">\n            </BuildableReference>\n         </BuildActionEntry>\n      </BuildActionEntries>\n   </BuildAction>\n   <TestAction\n      buildConfiguration = \"Debug\"\n      selectedDebuggerIdentifier = \"Xcode.DebuggerFoundation.Debugger.LLDB\"\n      selectedLauncherIdentifier = \"Xcode.DebuggerFoundation.Launcher.LLDB\"\n      shouldUseLaunchSchemeArgsEnv = \"YES\"\n      shouldAutocreateTestPlan = \"YES\">\n   </TestAction>\n   <LaunchAction\n      buildConfiguration = \"Release\"\n      selectedDebuggerIdentifier = \"Xcode.DebuggerFoundation.Debugger.LLDB\"\n      selectedLauncherIdentifier = \"Xcode.DebuggerFoundation.Launcher.LLDB\"\n      disableMainThreadChecker = \"YES\"\n      launchStyle = \"0\"\n      useCustomWorkingDirectory = \"NO\"\n      ignoresPersistentStateOnLaunch = \"NO\"\n      debugDocumentVersioning = \"YES\"\n      debugServiceExtension = \"internal\"\n      enableGPUFrameCaptureMode = \"3\"\n      enableGPUValidationMode = \"1\"\n      allowLocationSimulation = \"YES\"\n      disablePerformanceAntipatternChecker = \"YES\">\n      <BuildableProductRunnable\n         runnableDebuggingMode = \"0\">\n         <BuildableReference\n            BuildableIdentifier 
= \"primary\"\n            BlueprintIdentifier = \"C0D643AE29F99A7F004DDAA4\"\n            BuildableName = \"MLCChat.app\"\n            BlueprintName = \"MLCChat\"\n            ReferencedContainer = \"container:MLCChat.xcodeproj\">\n         </BuildableReference>\n      </BuildableProductRunnable>\n   </LaunchAction>\n   <ProfileAction\n      buildConfiguration = \"Release\"\n      shouldUseLaunchSchemeArgsEnv = \"YES\"\n      savedToolIdentifier = \"\"\n      useCustomWorkingDirectory = \"NO\"\n      debugDocumentVersioning = \"YES\">\n      <BuildableProductRunnable\n         runnableDebuggingMode = \"0\">\n         <BuildableReference\n            BuildableIdentifier = \"primary\"\n            BlueprintIdentifier = \"C0D643AE29F99A7F004DDAA4\"\n            BuildableName = \"MLCChat.app\"\n            BlueprintName = \"MLCChat\"\n            ReferencedContainer = \"container:MLCChat.xcodeproj\">\n         </BuildableReference>\n      </BuildableProductRunnable>\n   </ProfileAction>\n   <AnalyzeAction\n      buildConfiguration = \"Debug\">\n   </AnalyzeAction>\n   <ArchiveAction\n      buildConfiguration = \"Release\"\n      revealArchiveInOrganizer = \"YES\">\n   </ArchiveAction>\n</Scheme>\n"
  },
  {
    "path": "ios/MLCChat/README.md",
    "content": "# MLC Chat App\n\nCheck out the [documentation page](https://llm.mlc.ai/docs/deploy/ios.html) for more information.\n\n- Run `mlc_llm package`.\n- Open the Xcode project.\n"
  },
  {
    "path": "ios/MLCChat/mlc-package-config.json",
    "content": "{\n    \"device\": \"iphone\",\n    \"model_list\": [\n        {\n            \"model\": \"HF://mlc-ai/Llama-3.2-3B-Instruct-q4f16_1-MLC\",\n            \"model_id\": \"Llama-3.2-3B-Instruct-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128,\n                \"context_window_size\": 2048\n            },\n            \"bundle_weight\": true\n        },\n        {\n            \"model\": \"HF://mlc-ai/gemma-2-2b-it-q4f16_1-MLC\",\n            \"model_id\": \"gemma-2-2b-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128\n            }\n        },\n        {\n            \"model\": \"HF://mlc-ai/Phi-3.5-mini-instruct-q4f16_1-MLC\",\n            \"model_id\": \"Phi-3.5-mini-instruct-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3043000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128\n            }\n        },\n        {\n            \"model\": \"HF://mlc-ai/Qwen3-0.6B-q0f16-MLC\",\n            \"model_id\": \"Qwen3-0.6B-q0f16-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128,\n                \"context_window_size\": 2048\n            }\n        },\n        {\n            \"model\": \"HF://mlc-ai/Qwen3-1.7B-q4f16_1-MLC\",\n            \"model_id\": \"Qwen3-1.7B-q4f16_1-MLC\",\n            \"estimated_vram_bytes\": 3000000000,\n            \"overrides\": {\n                \"prefill_chunk_size\": 128,\n                \"context_window_size\": 2048\n            }\n        }\n    ]\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AccentColor.colorset/Contents.json",
    "content": "{\n  \"colors\" : [\n    {\n      \"idiom\" : \"universal\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/AppIcon.appiconset/Contents.json",
    "content": "{\n  \"images\" : [\n    {\n      \"idiom\" : \"universal\",\n      \"platform\" : \"ios\",\n      \"size\" : \"1024x1024\"\n    }\n  ],\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/Assets.xcassets/Contents.json",
    "content": "{\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/ContentView.swift",
    "content": "// This is a minimal example app to interact with MLC Engine.\n//\n// For a complete example, take a look at MLCChat.\n\nimport SwiftUI\n\nstruct ContentView: View {\n    @EnvironmentObject private var appState: AppState\n    // simply display the text in the app\n    var body: some View {\n        HStack {\n            Text(appState.displayText)\n            Spacer()\n        }\n        .padding()\n    }\n}\n\n#Preview {\n    ContentView()\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/MLCEngineExample.entitlements",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict>\n    <key>com.apple.developer.kernel.extended-virtual-addressing</key>\n    <true/>\n    <key>com.apple.developer.kernel.increased-memory-limit</key>\n    <true/>\n</dict>\n</plist>\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/MLCEngineExampleApp.swift",
    "content": "// This is a minimal example app to interact with MLC Engine.\n// This app is mainly created with minimalism in mind for\n// example and quick testing purposes.\n//\n// To build this app, select the target My Mac (Designed for iPad) and run it.\n// Make sure you run \"mlc_llm package\" first, with \"MLCChat\"\n// replaced by \"MLCEngineExample\",\n// so that the \"dist/bundle\" folder is populated with the right model files\n// and the model lib is packaged correctly.\nimport Foundation\nimport SwiftUI\n\nimport MLCSwift\n\nclass AppState: ObservableObject {\n    // the MLC engine instance\n    private let engine = MLCEngine()\n    // local path of the bundle folder that stores the model files\n    // (copied from dist/bundle at build time)\n    private let bundleURL = Bundle.main.bundleURL.appending(path: \"bundle\")\n    // model path, this must match a built-in\n    // file name in prepare_params.sh\n    private let modelPath = \"Llama-3-8B-Instruct-q3f16_1-MLC\"\n    // model lib identifier within the packaged library;\n    // make sure to run \"mlc_llm package\" first\n    private let modelLib = \"llama_q3f16_1\"\n\n    // message to be displayed in the app\n    @Published var displayText = \"\"\n\n    public func runExample() {\n        // MLCEngine is an actor that can be called in an async context\n        Task {\n            let modelLocalPath = bundleURL.appending(path: modelPath).path()\n            // Step 0: load the engine\n            await engine.reload(modelPath: modelLocalPath, modelLib: modelLib)\n\n            // run chat completion in the OpenAI API style\n            for await res in await engine.chat.completions.create(\n                messages: [\n                    ChatCompletionMessage(\n                        role: .user,\n                        content: \"What is the meaning of life?\"\n                    )\n                ],\n                stream_options: StreamOptions(include_usage: true)\n            ) {\n                // publish on the main event loop\n                DispatchQueue.main.async {\n                    // parse the result content in structured form\n                    // and stream it back to the display\n                    if let finalUsage = res.usage {\n                        self.displayText += \"\\n\" + (finalUsage.extra?.asTextLabel() ?? \"\")\n                    } else if let content = res.choices.first?.delta.content {\n                        // skip chunks without text content instead of force-unwrapping\n                        self.displayText += content.asText()\n                    }\n                }\n            }\n        }\n    }\n}\n\n\n@main\nstruct MLCEngineExampleApp: App {\n    private let appState = AppState()\n\n    init() {\n        // simply run the example;\n        // check the output in the console\n        appState.runExample()\n    }\n\n    var body: some Scene {\n        WindowGroup {\n            ContentView()\n                .environmentObject(appState)\n        }\n    }\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample/Preview Content/Preview Assets.xcassets/Contents.json",
    "content": "{\n  \"info\" : {\n    \"author\" : \"xcode\",\n    \"version\" : 1\n  }\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.pbxproj",
    "content": "// !$*UTF8*$!\n{\n    archiveVersion = 1;\n    classes = {\n    };\n    objectVersion = 60;\n    objects = {\n\n/* Begin PBXBuildFile section */\n        C04105DF2BEBC61B005A434D /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C04105DE2BEBC61B005A434D /* MLCSwift */; };\n        C07094522BEBC6C4005C29FC /* bundle in Copy Files */ = {isa = PBXBuildFile; fileRef = C07094512BEBC6C4005C29FC /* bundle */; };\n        C0B37B892BE8226A00B2F80B /* MLCEngineExampleApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */; };\n        C0B37B8B2BE8226A00B2F80B /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = C0B37B8A2BE8226A00B2F80B /* ContentView.swift */; };\n        C0B37B8D2BE8226B00B2F80B /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */; };\n        C0B37B902BE8226B00B2F80B /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */; };\n        C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */ = {isa = PBXBuildFile; productRef = C0B37B972BE8234D00B2F80B /* MLCSwift */; };\n/* End PBXBuildFile section */\n\n/* Begin PBXCopyFilesBuildPhase section */\n        C0B37B992BE8255600B2F80B /* Copy Files */ = {\n            isa = PBXCopyFilesBuildPhase;\n            buildActionMask = 12;\n            dstPath = \"\";\n            dstSubfolderSpec = 7;\n            files = (\n                C07094522BEBC6C4005C29FC /* bundle in Copy Files */,\n            );\n            name = \"Copy Files\";\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXCopyFilesBuildPhase section */\n\n/* Begin PBXFileReference section */\n        C07094512BEBC6C4005C29FC /* bundle */ = {isa = PBXFileReference; lastKnownFileType = folder; name = bundle; path = dist/bundle; sourceTree = \"<group>\"; };\n        C0B37B852BE8226A00B2F80B 
/* MLCEngineExample.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MLCEngineExample.app; sourceTree = BUILT_PRODUCTS_DIR; };\n        C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MLCEngineExampleApp.swift; sourceTree = \"<group>\"; };\n        C0B37B8A2BE8226A00B2F80B /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = \"<group>\"; };\n        C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = \"<group>\"; };\n        C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = \"Preview Assets.xcassets\"; sourceTree = \"<group>\"; };\n        C0B37C0C2BE8349300B2F80B /* MLCEngineExample.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = MLCEngineExample.entitlements; sourceTree = \"<group>\"; };\n/* End PBXFileReference section */\n\n/* Begin PBXFrameworksBuildPhase section */\n        C0B37B822BE8226A00B2F80B /* Frameworks */ = {\n            isa = PBXFrameworksBuildPhase;\n            buildActionMask = 2147483647;\n            files = (\n                C0B37B982BE8234D00B2F80B /* MLCSwift in Frameworks */,\n                C04105DF2BEBC61B005A434D /* MLCSwift in Frameworks */,\n            );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXFrameworksBuildPhase section */\n\n/* Begin PBXGroup section */\n        C0B37B7C2BE8226A00B2F80B = {\n            isa = PBXGroup;\n            children = (\n                C07094512BEBC6C4005C29FC /* bundle */,\n                C0B37B872BE8226A00B2F80B /* MLCEngineExample */,\n                C0B37B862BE8226A00B2F80B /* Products */,\n            );\n            sourceTree = 
\"<group>\";\n        };\n        C0B37B862BE8226A00B2F80B /* Products */ = {\n            isa = PBXGroup;\n            children = (\n                C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */,\n            );\n            name = Products;\n            sourceTree = \"<group>\";\n        };\n        C0B37B872BE8226A00B2F80B /* MLCEngineExample */ = {\n            isa = PBXGroup;\n            children = (\n                C0B37C0C2BE8349300B2F80B /* MLCEngineExample.entitlements */,\n                C0B37B882BE8226A00B2F80B /* MLCEngineExampleApp.swift */,\n                C0B37B8A2BE8226A00B2F80B /* ContentView.swift */,\n                C0B37B8C2BE8226B00B2F80B /* Assets.xcassets */,\n                C0B37B8E2BE8226B00B2F80B /* Preview Content */,\n            );\n            path = MLCEngineExample;\n            sourceTree = \"<group>\";\n        };\n        C0B37B8E2BE8226B00B2F80B /* Preview Content */ = {\n            isa = PBXGroup;\n            children = (\n                C0B37B8F2BE8226B00B2F80B /* Preview Assets.xcassets */,\n            );\n            path = \"Preview Content\";\n            sourceTree = \"<group>\";\n        };\n/* End PBXGroup section */\n\n/* Begin PBXNativeTarget section */\n        C0B37B842BE8226A00B2F80B /* MLCEngineExample */ = {\n            isa = PBXNativeTarget;\n            buildConfigurationList = C0B37B932BE8226B00B2F80B /* Build configuration list for PBXNativeTarget \"MLCEngineExample\" */;\n            buildPhases = (\n                C0B37B812BE8226A00B2F80B /* Sources */,\n                C0B37B822BE8226A00B2F80B /* Frameworks */,\n                C0B37B832BE8226A00B2F80B /* Resources */,\n                C0B37B992BE8255600B2F80B /* Copy Files */,\n            );\n            buildRules = (\n            );\n            dependencies = (\n            );\n            name = MLCEngineExample;\n            packageProductDependencies = (\n                C0B37B972BE8234D00B2F80B /* MLCSwift */,\n                
C04105DE2BEBC61B005A434D /* MLCSwift */,\n            );\n            productName = MLCEngineExample;\n            productReference = C0B37B852BE8226A00B2F80B /* MLCEngineExample.app */;\n            productType = \"com.apple.product-type.application\";\n        };\n/* End PBXNativeTarget section */\n\n/* Begin PBXProject section */\n        C0B37B7D2BE8226A00B2F80B /* Project object */ = {\n            isa = PBXProject;\n            attributes = {\n                BuildIndependentTargetsInParallel = 1;\n                LastSwiftUpdateCheck = 1530;\n                LastUpgradeCheck = 1530;\n                TargetAttributes = {\n                    C0B37B842BE8226A00B2F80B = {\n                        CreatedOnToolsVersion = 15.3;\n                    };\n                };\n            };\n            buildConfigurationList = C0B37B802BE8226A00B2F80B /* Build configuration list for PBXProject \"MLCEngineExample\" */;\n            compatibilityVersion = \"Xcode 14.0\";\n            developmentRegion = en;\n            hasScannedForEncodings = 0;\n            knownRegions = (\n                en,\n                Base,\n            );\n            mainGroup = C0B37B7C2BE8226A00B2F80B;\n            packageReferences = (\n                C0B37B962BE8234D00B2F80B /* XCLocalSwiftPackageReference \"../MLCSwift\" */,\n            );\n            productRefGroup = C0B37B862BE8226A00B2F80B /* Products */;\n            projectDirPath = \"\";\n            projectRoot = \"\";\n            targets = (\n                C0B37B842BE8226A00B2F80B /* MLCEngineExample */,\n            );\n        };\n/* End PBXProject section */\n\n/* Begin PBXResourcesBuildPhase section */\n        C0B37B832BE8226A00B2F80B /* Resources */ = {\n            isa = PBXResourcesBuildPhase;\n            buildActionMask = 2147483647;\n            files = (\n                C0B37B902BE8226B00B2F80B /* Preview Assets.xcassets in Resources */,\n                C0B37B8D2BE8226B00B2F80B /* Assets.xcassets in 
Resources */,\n            );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXResourcesBuildPhase section */\n\n/* Begin PBXSourcesBuildPhase section */\n        C0B37B812BE8226A00B2F80B /* Sources */ = {\n            isa = PBXSourcesBuildPhase;\n            buildActionMask = 2147483647;\n            files = (\n                C0B37B8B2BE8226A00B2F80B /* ContentView.swift in Sources */,\n                C0B37B892BE8226A00B2F80B /* MLCEngineExampleApp.swift in Sources */,\n            );\n            runOnlyForDeploymentPostprocessing = 0;\n        };\n/* End PBXSourcesBuildPhase section */\n\n/* Begin XCBuildConfiguration section */\n        C0B37B912BE8226B00B2F80B /* Debug */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = {\n                ALWAYS_SEARCH_USER_PATHS = NO;\n                ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;\n                CLANG_ANALYZER_NONNULL = YES;\n                CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;\n                CLANG_CXX_LANGUAGE_STANDARD = \"gnu++20\";\n                CLANG_ENABLE_MODULES = YES;\n                CLANG_ENABLE_OBJC_ARC = YES;\n                CLANG_ENABLE_OBJC_WEAK = YES;\n                CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;\n                CLANG_WARN_BOOL_CONVERSION = YES;\n                CLANG_WARN_COMMA = YES;\n                CLANG_WARN_CONSTANT_CONVERSION = YES;\n                CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;\n                CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;\n                CLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n                CLANG_WARN_EMPTY_BODY = YES;\n                CLANG_WARN_ENUM_CONVERSION = YES;\n                CLANG_WARN_INFINITE_RECURSION = YES;\n                CLANG_WARN_INT_CONVERSION = YES;\n                CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;\n                CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;\n                
CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;\n                CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;\n                CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;\n                CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;\n                CLANG_WARN_STRICT_PROTOTYPES = YES;\n                CLANG_WARN_SUSPICIOUS_MOVE = YES;\n                CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;\n                CLANG_WARN_UNREACHABLE_CODE = YES;\n                CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;\n                COPY_PHASE_STRIP = NO;\n                DEBUG_INFORMATION_FORMAT = dwarf;\n                ENABLE_STRICT_OBJC_MSGSEND = YES;\n                ENABLE_TESTABILITY = YES;\n                ENABLE_USER_SCRIPT_SANDBOXING = YES;\n                GCC_C_LANGUAGE_STANDARD = gnu17;\n                GCC_DYNAMIC_NO_PIC = NO;\n                GCC_NO_COMMON_BLOCKS = YES;\n                GCC_OPTIMIZATION_LEVEL = 0;\n                GCC_PREPROCESSOR_DEFINITIONS = (\n                    \"DEBUG=1\",\n                    \"$(inherited)\",\n                );\n                GCC_WARN_64_TO_32_BIT_CONVERSION = YES;\n                GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;\n                GCC_WARN_UNDECLARED_SELECTOR = YES;\n                GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;\n                GCC_WARN_UNUSED_FUNCTION = YES;\n                GCC_WARN_UNUSED_VARIABLE = YES;\n                IPHONEOS_DEPLOYMENT_TARGET = 17.4;\n                LOCALIZATION_PREFERS_STRING_CATALOGS = YES;\n                MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;\n                MTL_FAST_MATH = YES;\n                ONLY_ACTIVE_ARCH = YES;\n                SDKROOT = iphoneos;\n                SWIFT_ACTIVE_COMPILATION_CONDITIONS = \"DEBUG $(inherited)\";\n                SWIFT_OPTIMIZATION_LEVEL = \"-Onone\";\n            };\n            name = Debug;\n        };\n        C0B37B922BE8226B00B2F80B /* Release */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = 
{\n                ALWAYS_SEARCH_USER_PATHS = NO;\n                ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;\n                CLANG_ANALYZER_NONNULL = YES;\n                CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;\n                CLANG_CXX_LANGUAGE_STANDARD = \"gnu++20\";\n                CLANG_ENABLE_MODULES = YES;\n                CLANG_ENABLE_OBJC_ARC = YES;\n                CLANG_ENABLE_OBJC_WEAK = YES;\n                CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;\n                CLANG_WARN_BOOL_CONVERSION = YES;\n                CLANG_WARN_COMMA = YES;\n                CLANG_WARN_CONSTANT_CONVERSION = YES;\n                CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;\n                CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;\n                CLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n                CLANG_WARN_EMPTY_BODY = YES;\n                CLANG_WARN_ENUM_CONVERSION = YES;\n                CLANG_WARN_INFINITE_RECURSION = YES;\n                CLANG_WARN_INT_CONVERSION = YES;\n                CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;\n                CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;\n                CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;\n                CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;\n                CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;\n                CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;\n                CLANG_WARN_STRICT_PROTOTYPES = YES;\n                CLANG_WARN_SUSPICIOUS_MOVE = YES;\n                CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;\n                CLANG_WARN_UNREACHABLE_CODE = YES;\n                CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;\n                COPY_PHASE_STRIP = NO;\n                DEBUG_INFORMATION_FORMAT = \"dwarf-with-dsym\";\n                ENABLE_NS_ASSERTIONS = NO;\n                ENABLE_STRICT_OBJC_MSGSEND = YES;\n                ENABLE_USER_SCRIPT_SANDBOXING = YES;\n                GCC_C_LANGUAGE_STANDARD = 
gnu17;\n                GCC_NO_COMMON_BLOCKS = YES;\n                GCC_WARN_64_TO_32_BIT_CONVERSION = YES;\n                GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;\n                GCC_WARN_UNDECLARED_SELECTOR = YES;\n                GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;\n                GCC_WARN_UNUSED_FUNCTION = YES;\n                GCC_WARN_UNUSED_VARIABLE = YES;\n                IPHONEOS_DEPLOYMENT_TARGET = 17.4;\n                LOCALIZATION_PREFERS_STRING_CATALOGS = YES;\n                MTL_ENABLE_DEBUG_INFO = NO;\n                MTL_FAST_MATH = YES;\n                SDKROOT = iphoneos;\n                SWIFT_COMPILATION_MODE = wholemodule;\n                VALIDATE_PRODUCT = YES;\n            };\n            name = Release;\n        };\n        C0B37B942BE8226B00B2F80B /* Debug */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = {\n                ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;\n                ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;\n                CODE_SIGN_ENTITLEMENTS = MLCEngineExample/MLCEngineExample.entitlements;\n                CODE_SIGN_STYLE = Automatic;\n                CURRENT_PROJECT_VERSION = 1;\n                DEVELOPMENT_ASSET_PATHS = \"\\\"MLCEngineExample/Preview Content\\\"\";\n                DEVELOPMENT_TEAM = 3FR42MXLK9;\n                ENABLE_PREVIEWS = YES;\n                GENERATE_INFOPLIST_FILE = YES;\n                INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;\n                INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;\n                INFOPLIST_KEY_UILaunchScreen_Generation = YES;\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = \"UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = \"UIInterfaceOrientationPortrait 
UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                IPHONEOS_DEPLOYMENT_TARGET = 16.0;\n                LD_RUNPATH_SEARCH_PATHS = (\n                    \"$(inherited)\",\n                    \"@executable_path/Frameworks\",\n                );\n                LIBRARY_SEARCH_PATHS = \"${PROJECT_DIR}/dist/lib\";\n                MARKETING_VERSION = 1.0;\n                OTHER_LDFLAGS = (\n                    \"-Wl,-all_load\",\n                    \"-lmodel_iphone\",\n                    \"-lmlc_llm\",\n                    \"-ltvm_runtime\",\n                    \"-ltokenizers_cpp\",\n                    \"-lsentencepiece\",\n                    \"-ltokenizers_c\",\n                );\n                PRODUCT_BUNDLE_IDENTIFIER = mlc.MLCEngineExample;\n                PRODUCT_NAME = \"$(TARGET_NAME)\";\n                SWIFT_EMIT_LOC_STRINGS = YES;\n                SWIFT_VERSION = 5.0;\n                TARGETED_DEVICE_FAMILY = \"1,2\";\n            };\n            name = Debug;\n        };\n        C0B37B952BE8226B00B2F80B /* Release */ = {\n            isa = XCBuildConfiguration;\n            buildSettings = {\n                ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;\n                ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;\n                CODE_SIGN_ENTITLEMENTS = MLCEngineExample/MLCEngineExample.entitlements;\n                CODE_SIGN_STYLE = Automatic;\n                CURRENT_PROJECT_VERSION = 1;\n                DEVELOPMENT_ASSET_PATHS = \"\\\"MLCEngineExample/Preview Content\\\"\";\n                DEVELOPMENT_TEAM = 3FR42MXLK9;\n                ENABLE_PREVIEWS = YES;\n                GENERATE_INFOPLIST_FILE = YES;\n                INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;\n                INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;\n                INFOPLIST_KEY_UILaunchScreen_Generation = YES;\n                
INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = \"UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = \"UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight\";\n                IPHONEOS_DEPLOYMENT_TARGET = 16.0;\n                LD_RUNPATH_SEARCH_PATHS = (\n                    \"$(inherited)\",\n                    \"@executable_path/Frameworks\",\n                );\n                LIBRARY_SEARCH_PATHS = \"${PROJECT_DIR}/dist/lib\";\n                MARKETING_VERSION = 1.0;\n                OTHER_LDFLAGS = (\n                    \"-Wl,-all_load\",\n                    \"-lmodel_iphone\",\n                    \"-lmlc_llm\",\n                    \"-ltvm_runtime\",\n                    \"-ltokenizers_cpp\",\n                    \"-lsentencepiece\",\n                    \"-ltokenizers_c\",\n                );\n                PRODUCT_BUNDLE_IDENTIFIER = mlc.MLCEngineExample;\n                PRODUCT_NAME = \"$(TARGET_NAME)\";\n                SWIFT_EMIT_LOC_STRINGS = YES;\n                SWIFT_VERSION = 5.0;\n                TARGETED_DEVICE_FAMILY = \"1,2\";\n            };\n            name = Release;\n        };\n/* End XCBuildConfiguration section */\n\n/* Begin XCConfigurationList section */\n        C0B37B802BE8226A00B2F80B /* Build configuration list for PBXProject \"MLCEngineExample\" */ = {\n            isa = XCConfigurationList;\n            buildConfigurations = (\n                C0B37B912BE8226B00B2F80B /* Debug */,\n                C0B37B922BE8226B00B2F80B /* Release */,\n            );\n            defaultConfigurationIsVisible = 0;\n            defaultConfigurationName = Release;\n        };\n        C0B37B932BE8226B00B2F80B /* Build configuration list for PBXNativeTarget \"MLCEngineExample\" */ = {\n            isa = 
XCConfigurationList;\n            buildConfigurations = (\n                C0B37B942BE8226B00B2F80B /* Debug */,\n                C0B37B952BE8226B00B2F80B /* Release */,\n            );\n            defaultConfigurationIsVisible = 0;\n            defaultConfigurationName = Release;\n        };\n/* End XCConfigurationList section */\n\n/* Begin XCLocalSwiftPackageReference section */\n        C0B37B962BE8234D00B2F80B /* XCLocalSwiftPackageReference \"../MLCSwift\" */ = {\n            isa = XCLocalSwiftPackageReference;\n            relativePath = ../MLCSwift;\n        };\n/* End XCLocalSwiftPackageReference section */\n\n/* Begin XCSwiftPackageProductDependency section */\n        C04105DE2BEBC61B005A434D /* MLCSwift */ = {\n            isa = XCSwiftPackageProductDependency;\n            productName = MLCSwift;\n        };\n        C0B37B972BE8234D00B2F80B /* MLCSwift */ = {\n            isa = XCSwiftPackageProductDependency;\n            productName = MLCSwift;\n        };\n/* End XCSwiftPackageProductDependency section */\n    };\n    rootObject = C0B37B7D2BE8226A00B2F80B /* Project object */;\n}\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/contents.xcworkspacedata",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<Workspace\n   version = \"1.0\">\n   <FileRef\n      location = \"self:\">\n   </FileRef>\n</Workspace>\n"
  },
  {
    "path": "ios/MLCEngineExample/MLCEngineExample.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">\n<plist version=\"1.0\">\n<dict>\n    <key>IDEDidComputeMac32BitWarning</key>\n    <true/>\n</dict>\n</plist>\n"
  },
  {
    "path": "ios/MLCEngineExample/README.md",
    "content": "# MLCEngine Example\n\nA minimal example of the MLCSwift API.\n\nCheck out the [documentation page](https://llm.mlc.ai/docs/deploy/ios.html) for more information.\n\n- Run `mlc_llm package`.\n- Open the Xcode project.\n"
  },
  {
    "path": "ios/MLCEngineExample/mlc-package-config.json",
    "content": "{\n    \"device\": \"iphone\",\n    \"model_list\": [\n        {\n            \"model\": \"HF://mlc-ai/Llama-3-8B-Instruct-q3f16_1-MLC\",\n            \"model_id\": \"Llama-3-8B-Instruct-q3f16_1-MLC\",\n            \"estimated_vram_bytes\": 3316000000,\n            \"bundle_weight\": true,\n            \"model_lib\": \"llama_q3f16_1\"\n        }\n    ]\n}\n"
  },
  {
    "path": "ios/MLCSwift/Package.swift",
    "content": "// swift-tools-version:5.5\n// The swift-tools-version declares the minimum version of Swift required to build this package.\n\nimport PackageDescription\n\nlet package = Package(\n    name: \"MLCSwift\",\n    products: [\n        .library(\n            name: \"MLCSwift\",\n            targets: [\"MLCEngineObjC\", \"MLCSwift\"]\n        )\n    ],\n    dependencies: [],\n    targets: [\n        .target(\n            name: \"MLCEngineObjC\",\n            path: \"Sources/ObjC\",\n            cxxSettings: [\n                .headerSearchPath(\"../../tvm_home/include\"),\n                .headerSearchPath(\"../../tvm_home/3rdparty/tvm-ffi/include\"),\n                .headerSearchPath(\"../../tvm_home/3rdparty/tvm-ffi/3rdparty/dlpack/include\")\n            ]\n        ),\n        .target(\n            name: \"MLCSwift\",\n            dependencies: [\"MLCEngineObjC\"],\n            path: \"Sources/Swift\"\n        )\n    ],\n    cxxLanguageStandard: .cxx17\n)\n"
  },
  {
    "path": "ios/MLCSwift/README.md",
    "content": "# MLCSwift\n\nThis is a simple Swift package that exposes the chat module to Swift.\nCheck out our [documentation](https://llm.mlc.ai/docs/) for more examples.\n"
  },
  {
    "path": "ios/MLCSwift/Sources/ObjC/LLMEngine.mm",
    "content": "//\n//  LLMEngine.mm\n//  LLMEngine\n//\n#import <Foundation/Foundation.h>\n#import <UIKit/UIKit.h>\n#include <os/proc.h>\n\n#include \"LLMEngine.h\"\n\n#define TVM_USE_LIBBACKTRACE 0\n\n#include <tvm/ffi/extra/module.h>\n#include <tvm/ffi/function.h>\n#include <tvm/ffi/optional.h>\n#include <tvm/ffi/string.h>\n#include <tvm/runtime/module.h>\n\nusing namespace tvm::runtime;\nusing tvm::ffi::Function;\nusing tvm::ffi::Module;\nusing tvm::ffi::Optional;\nusing tvm::ffi::String;\nusing tvm::ffi::TypedFunction;\n\n@implementation JSONFFIEngine {\n  // Internal c++ classes\n  // internal module backed by JSON FFI\n  Optional<Module> json_ffi_engine_;\n  // member functions\n  Function init_background_engine_func_;\n  Function unload_func_;\n  Function reload_func_;\n  Function reset_func_;\n  Function chat_completion_func_;\n  Function abort_func_;\n  Function run_background_loop_func_;\n  Function run_background_stream_back_loop_func_;\n  Function exit_background_loop_func_;\n}\n\n- (instancetype)init {\n  if (self = [super init]) {\n    // load chat module\n    Function f_json_ffi_create = Function::GetGlobalRequired(\"mlc.json_ffi.CreateJSONFFIEngine\");\n    json_ffi_engine_ = f_json_ffi_create().cast<Module>();\n    init_background_engine_func_ =\n        json_ffi_engine_.value()->GetFunction(\"init_background_engine\").value_or(Function(nullptr));\n    reload_func_ = json_ffi_engine_.value()->GetFunction(\"reload\").value_or(Function(nullptr));\n    unload_func_ = json_ffi_engine_.value()->GetFunction(\"unload\").value_or(Function(nullptr));\n    reset_func_ = json_ffi_engine_.value()->GetFunction(\"reset\").value_or(Function(nullptr));\n    chat_completion_func_ =\n        json_ffi_engine_.value()->GetFunction(\"chat_completion\").value_or(Function(nullptr));\n    abort_func_ = json_ffi_engine_.value()->GetFunction(\"abort\").value_or(Function(nullptr));\n    run_background_loop_func_ =\n        
json_ffi_engine_.value()->GetFunction(\"run_background_loop\").value_or(Function(nullptr));\n    run_background_stream_back_loop_func_ = json_ffi_engine_.value()\n                                                ->GetFunction(\"run_background_stream_back_loop\")\n                                                .value_or(Function(nullptr));\n    exit_background_loop_func_ =\n        json_ffi_engine_.value()->GetFunction(\"exit_background_loop\").value_or(Function(nullptr));\n\n    TVM_FFI_ICHECK(init_background_engine_func_ != nullptr);\n    TVM_FFI_ICHECK(reload_func_ != nullptr);\n    TVM_FFI_ICHECK(unload_func_ != nullptr);\n    TVM_FFI_ICHECK(reset_func_ != nullptr);\n    TVM_FFI_ICHECK(chat_completion_func_ != nullptr);\n    TVM_FFI_ICHECK(abort_func_ != nullptr);\n    TVM_FFI_ICHECK(run_background_loop_func_ != nullptr);\n    TVM_FFI_ICHECK(run_background_stream_back_loop_func_ != nullptr);\n    TVM_FFI_ICHECK(exit_background_loop_func_ != nullptr);\n  }\n  return self;\n}\n\n- (void)initBackgroundEngine:(void (^)(NSString*))streamCallback {\n  TypedFunction<void(String)> internal_stream_callback([streamCallback](String value) {\n    streamCallback([NSString stringWithUTF8String:value.c_str()]);\n  });\n  int device_type = kDLMetal;\n  int device_id = 0;\n  init_background_engine_func_(device_type, device_id, internal_stream_callback);\n}\n\n- (void)reload:(NSString*)engineConfigJson {\n  std::string engine_config = engineConfigJson.UTF8String;\n  reload_func_(engine_config);\n}\n\n- (void)unload {\n  unload_func_();\n}\n\n- (void)reset {\n  reset_func_();\n}\n\n- (void)chatCompletion:(NSString*)requestJSON requestID:(NSString*)requestID {\n  std::string request_json = requestJSON.UTF8String;\n  std::string request_id = requestID.UTF8String;\n  chat_completion_func_(request_json, request_id);\n}\n\n- (void)abort:(NSString*)requestID {\n  std::string request_id = requestID.UTF8String;\n  abort_func_(request_id);\n}\n\n- (void)runBackgroundLoop {\n  
run_background_loop_func_();\n}\n\n- (void)runBackgroundStreamBackLoop {\n  run_background_stream_back_loop_func_();\n}\n\n- (void)exitBackgroundLoop {\n  exit_background_loop_func_();\n}\n\n@end\n"
  },
  {
    "path": "ios/MLCSwift/Sources/ObjC/include/LLMEngine.h",
    "content": "//\n//  Use this file to import your target's public headers that you would like to expose to Swift.\n//  LLM Chat Module\n//\n// Exposed Objective-C interface, which enables the Swift binding.\n#import <Foundation/Foundation.h>\n#import <UIKit/UIKit.h>\n\n/**\n * An internal raw JSON FFI engine that forwards requests to the internal JSON FFI engine in C++.\n */\n@interface JSONFFIEngine : NSObject\n\n- (void)initBackgroundEngine:(void (^)(NSString*))streamCallback;\n\n- (void)reload:(NSString*)engineConfig;\n\n- (void)unload;\n\n- (void)reset;\n\n- (void)chatCompletion:(NSString*)requestJSON requestID:(NSString*)requestID;\n\n- (void)abort:(NSString*)requestID;\n\n- (void)runBackgroundLoop;\n\n- (void)runBackgroundStreamBackLoop;\n\n- (void)exitBackgroundLoop;\n\n@end\n"
  },
  {
    "path": "ios/MLCSwift/Sources/Swift/LLMEngine.swift",
    "content": "import Foundation\nimport MLCEngineObjC\nimport os\n\nclass BackgroundWorker : Thread {\n    private var task: ()->Void;\n\n    public init(task: @escaping () -> Void) {\n        self.task = task\n    }\n\n    public override func main()  {\n        self.task();\n    }\n}\n\n@available(iOS 14.0.0, *)\npublic class MLCEngine {\n    struct RequestState {\n        let request: ChatCompletionRequest\n        let continuation: AsyncStream<ChatCompletionStreamResponse>.Continuation\n\n        init(\n            request: ChatCompletionRequest,\n            continuation: AsyncStream<ChatCompletionStreamResponse>.Continuation\n        ) {\n            self.request = request\n            self.continuation = continuation\n        }\n    }\n\n    // internal engine state\n    // that maintains logger and continuations\n    // we decouple it from MLCEngine\n    // and explicitly pass in jsonFFIEngine\n    // so there is no cyclic dependency\n    // when we capture things\n    actor EngineState {\n        public let logger = Logger()\n        private var requestStateMap = Dictionary<String, RequestState>()\n\n        // completion function\n        func chatCompletion(\n            jsonFFIEngine: JSONFFIEngine,\n            request: ChatCompletionRequest\n        ) -> AsyncStream<ChatCompletionStreamResponse> {\n            let encoder = JSONEncoder()\n            let data = try! 
encoder.encode(request)\n            let jsonRequest = String(data: data, encoding: .utf8)!\n            // generate a UUID for the request\n            let requestID = UUID().uuidString\n            let stream = AsyncStream(ChatCompletionStreamResponse.self) { continuation in\n                continuation.onTermination = { termination in\n                    if termination == .cancelled {\n                        jsonFFIEngine.abort(requestID);\n                    }\n                }\n                // store the continuation in the map for later callbacks\n                self.requestStateMap[requestID] = RequestState(\n                    request: request, continuation: continuation\n                )\n                // start invoking the engine for completion\n                jsonFFIEngine.chatCompletion(jsonRequest, requestID: requestID)\n            }\n            return stream\n        }\n\n        func streamCallback(result: String?) {\n            var responses: [ChatCompletionStreamResponse] = []\n\n            let decoder = JSONDecoder()\n            do {\n                responses = try decoder.decode([ChatCompletionStreamResponse].self, from: result!.data(using: .utf8)!)\n            } catch let lastError {\n                logger.error(\"Swift json parsing error: error=\\(lastError), jsonsrc=\\(result!)\")\n            }\n\n            // dispatch each response to its request ID\n            for res in responses {\n                if let requestState = self.requestStateMap[res.id] {\n                    // the final chunk always comes with usage\n                    if res.usage != nil {\n                        if let include_usage = requestState.request.stream_options?.include_usage {\n                            if include_usage {\n                                requestState.continuation.yield(res)\n                            }\n                        }\n                        requestState.continuation.finish()\n                        
self.requestStateMap.removeValue(forKey: res.id)\n                    } else {\n                        requestState.continuation.yield(res)\n                    }\n                }\n            }\n            // TODO(mlc-team): check the last error in the engine and report it if there is one\n        }\n    }\n\n    public class Completions {\n        private let jsonFFIEngine: JSONFFIEngine\n        private let state: EngineState\n\n        init(jsonFFIEngine: JSONFFIEngine, state: EngineState) {\n            self.jsonFFIEngine = jsonFFIEngine\n            self.state = state\n        }\n\n        private func create(\n            request: ChatCompletionRequest\n        ) async -> AsyncStream<ChatCompletionStreamResponse> {\n            return await state.chatCompletion(jsonFFIEngine: jsonFFIEngine, request: request)\n        }\n\n        // convenience overload that takes the messages directly\n        public func create(\n            messages: [ChatCompletionMessage],\n            model: Optional<String> = nil,\n            frequency_penalty: Optional<Float> = nil,\n            presence_penalty: Optional<Float> = nil,\n            logprobs: Bool = false,\n            top_logprobs: Int = 0,\n            logit_bias: Optional<[Int : Float]> = nil,\n            max_tokens: Optional<Int> = nil,\n            n: Int = 1,\n            seed: Optional<Int> = nil,\n            stop: Optional<[String]> = nil,\n            stream: Bool = true,\n            stream_options: Optional<StreamOptions> = nil,\n            temperature: Optional<Float> = nil,\n            top_p: Optional<Float> = nil,\n            tools: Optional<[ChatTool]> = nil,\n            user: Optional<String> = nil,\n            response_format: Optional<ResponseFormat> = nil\n        ) async -> AsyncStream<ChatCompletionStreamResponse> {\n            if !stream {\n                state.logger.error(\"Only stream=true is supported in MLCSwift\")\n            }\n            let request = ChatCompletionRequest(\n       
         messages: messages,\n                model: model,\n                frequency_penalty: frequency_penalty,\n                presence_penalty: presence_penalty,\n                logprobs: logprobs,\n                top_logprobs: top_logprobs,\n                logit_bias: logit_bias,\n                max_tokens: max_tokens,\n                n: n,\n                seed: seed,\n                stop: stop,\n                stream: stream,\n                stream_options: stream_options,\n                temperature: temperature,\n                top_p: top_p,\n                tools: tools,\n                user: user,\n                response_format: response_format\n            )\n            return await self.create(request: request)\n        }\n    }\n\n    public class Chat {\n        public let completions: Completions\n\n        init(jsonFFIEngine: JSONFFIEngine, state: EngineState) {\n            self.completions = Completions(\n                jsonFFIEngine: jsonFFIEngine,\n                state: state\n            )\n        }\n    }\n\n    private let state : EngineState;\n    private let jsonFFIEngine: JSONFFIEngine;\n    public let chat : Chat;\n    private var threads = Array<Thread>();\n\n    public init() {\n        let state_ = EngineState();\n        let jsonFFIEngine_ = JSONFFIEngine();\n\n        self.chat = Chat(jsonFFIEngine: jsonFFIEngine_, state: state_)\n        self.jsonFFIEngine = jsonFFIEngine_\n        self.state = state_\n\n        // note: the closure does not capture self\n        jsonFFIEngine_.initBackgroundEngine {\n            [state_](result : String?) 
-> Void in\n            state_.streamCallback(result: result)\n        }\n        let backgroundWorker = BackgroundWorker { [jsonFFIEngine_] in\n            Thread.setThreadPriority(1)\n            jsonFFIEngine_.runBackgroundLoop()\n        }\n        let backgroundStreamBackWorker = BackgroundWorker {\n            [jsonFFIEngine_] in\n            jsonFFIEngine_.runBackgroundStreamBackLoop()\n        }\n        // give the background worker high QoS so it gets higher priority for the GPU\n        backgroundWorker.qualityOfService = QualityOfService.userInteractive\n        threads.append(backgroundWorker)\n        threads.append(backgroundStreamBackWorker)\n        backgroundWorker.start()\n        backgroundStreamBackWorker.start()\n    }\n\n    deinit {\n        jsonFFIEngine.exitBackgroundLoop()\n    }\n\n    // The following functions do not need to be async for now,\n    // but we keep them async to stay consistent with chat.completions.create\n    // and to leave room for future API changes\n    public func reload(modelPath: String, modelLib: String) async {\n        let engineConfig = \"\"\"\n        {\n            \"model\": \"\\(modelPath)\",\n            \"model_lib\": \"system://\\(modelLib)\",\n            \"mode\": \"interactive\"\n        }\n        \"\"\"\n        jsonFFIEngine.reload(engineConfig)\n    }\n\n    public func reset() async {\n        jsonFFIEngine.reset()\n    }\n\n    public func unload() async {\n        jsonFFIEngine.unload()\n    }\n}\n"
  },
  {
    "path": "ios/MLCSwift/Sources/Swift/OpenAIProtocol.swift",
    "content": "// Protocol definition of OpenAI API\nimport Foundation\n\n// Protocols for v1/chat/completions\n// API reference: https://platform.openai.com/docs/api-reference/chat/create\n\npublic struct TopLogProbs : Codable {\n    public var token: String\n    public var logprob: Float\n    public var bytes: Optional<[Int]>\n}\n\npublic struct LogProbsContent : Codable {\n    public var token: String\n    public var logprob: Float\n    public var bytes: Optional<[Int]> = nil\n    public var top_logprobs: [TopLogProbs] = []\n}\n\npublic struct LogProbs : Codable {\n    public var content: [LogProbsContent] = []\n}\n\npublic struct ChatFunction : Codable {\n    public var name: String\n    public var description: Optional<String> = nil\n    public var parameters: [String: String]\n\n    public init(\n        name: String,\n        description: Optional<String> = nil,\n        parameters: [String : String]\n    ) {\n        self.name = name\n        self.description = description\n        self.parameters = parameters\n    }\n}\n\npublic struct ChatTool : Codable {\n    public var type: String = \"function\"\n    public let function: ChatFunction\n\n    public init(type: String, function: ChatFunction) {\n        self.type = type\n        self.function = function\n    }\n}\n\npublic struct ChatFunctionCall : Codable {\n    public var name: String\n    // NOTE: arguments should be a dict from String to any Codable value;\n    // for now only String values are allowed due to typing issues\n    public var arguments: Optional<[String: String]> = nil\n\n    public init(name: String, arguments: Optional<[String : String]> = nil) {\n        self.name = name\n        self.arguments = arguments\n    }\n}\n\npublic struct ChatToolCall : Codable {\n    public var id: String = UUID().uuidString\n    public var type: String = \"function\"\n    public var function: ChatFunctionCall\n\n    public init(\n        id: String = UUID().uuidString,\n        type: String = 
\"function\",\n        function: 
ChatFunctionCall\n    ) {\n        self.id = id\n        self.type = type\n        self.function = function\n    }\n}\n\npublic enum ChatCompletionRole: String, Codable {\n    case system = \"system\"\n    case user = \"user\"\n    case assistant = \"assistant\"\n    case tool = \"tool\"\n}\n\npublic enum ChatCompletionMessageContent: Codable {\n    case text(String)\n    case parts([[String: String]])\n\n    public init(from decoder: Decoder) throws {\n        let container = try decoder.singleValueContainer()\n        if let text = try? container.decode(String.self) {\n            self = .text(text)\n        } else {\n            let parts = try container.decode([[String: String]].self)\n            self = .parts(parts)\n        }\n    }\n\n    public func encode(to encoder: Encoder) throws {\n        var container = encoder.singleValueContainer()\n        switch self {\n        case .text(let text): try container.encode(text)\n        case .parts(let parts): try container.encode(parts)\n        }\n    }\n\n    public func asText() -> String {\n        switch (self) {\n        case .text(let text): return text\n        case .parts(let parts):\n            var res = \"\"\n            for item in parts {\n                if item[\"type\"]! 
== \"text\" {\n                    res += item[\"text\"]!\n                }\n            }\n            return res\n        }\n    }\n}\n\npublic struct ChatCompletionMessage: Codable {\n    public var role: ChatCompletionRole\n    public var content: Optional<ChatCompletionMessageContent> = nil\n    public var name: Optional<String> = nil\n    public var tool_calls: Optional<[ChatToolCall]> = nil\n    public var tool_call_id: Optional<String> = nil\n\n    // construct content from a list of parts\n    public init(\n        role: ChatCompletionRole,\n        content: Optional<[[String : String]]> = nil,\n        name: Optional<String> = nil,\n        tool_calls: Optional<[ChatToolCall]> = nil,\n        tool_call_id: Optional<String> = nil\n    ) {\n        self.role = role\n        if let cvalue = content {\n            self.content = .parts(cvalue)\n        } else {\n            self.content = nil\n        }\n        self.name = name\n        self.tool_calls = tool_calls\n        self.tool_call_id = tool_call_id\n    }\n\n    // convenience initializer that constructs content from a string\n    public init(\n        role: ChatCompletionRole,\n        content: String,\n        name: Optional<String> = nil,\n        tool_calls: Optional<[ChatToolCall]> = nil,\n        tool_call_id: Optional<String> = nil\n    ) {\n        self.role = role\n        self.content = .text(content)\n        self.name = name\n        self.tool_calls = tool_calls\n        self.tool_call_id = tool_call_id\n    }\n}\n\npublic struct ChatCompletionStreamResponseChoice: Codable {\n    public var finish_reason: Optional<String> = nil\n    public var index: Int\n    public var delta: ChatCompletionMessage\n    public var logprobs: Optional<LogProbs> = nil\n}\n\npublic struct CompletionUsageExtra: Codable {\n    public var prefill_tokens_per_s: Optional<Float> = nil\n    public var decode_tokens_per_s: Optional<Float> = nil\n    public var num_prefill_tokens: Optional<Int> = nil\n\n    public func 
asTextLabel() -> String {\n        var outputText = \"\"\n        if let prefill_tokens_per_s = self.prefill_tokens_per_s {\n            outputText += \"prefill: \"\n            outputText += String(format: \"%.1f\", prefill_tokens_per_s)\n            outputText += \" tok/s\"\n        }\n        if let decode_tokens_per_s = self.decode_tokens_per_s {\n            if !outputText.isEmpty {\n                outputText += \", \"\n            }\n            outputText += \"decode: \"\n            outputText += String(format: \"%.1f\", decode_tokens_per_s)\n            outputText += \" tok/s\"\n        }\n        return outputText\n    }\n}\n\npublic struct CompletionUsage: Codable {\n    public var prompt_tokens: Int\n    public var completion_tokens: Int\n    public var total_tokens: Int\n    public var extra: Optional<CompletionUsageExtra>\n}\n\npublic struct ChatCompletionStreamResponse: Codable {\n    public var id : String\n    public var choices: [ChatCompletionStreamResponseChoice] = []\n    public var created: Optional<Int> = nil\n    public var model: Optional<String> = nil\n    public var system_fingerprint: String\n    public var object: Optional<String> = nil\n    public var usage: Optional<CompletionUsage> = nil\n}\n\npublic struct ResponseFormat: Codable {\n    public var type: String\n    public var schema: Optional<String> = nil\n\n    public init(type: String, schema: Optional<String> = nil) {\n        self.type = type\n        self.schema = schema\n    }\n}\n\npublic struct StreamOptions: Codable {\n    public var include_usage: Bool = false\n\n    public init(include_usage: Bool) {\n        self.include_usage = include_usage\n    }\n}\n\npublic struct ChatCompletionRequest: Codable {\n    public var messages: [ChatCompletionMessage]\n    public var model: Optional<String> = nil\n    public var frequency_penalty: Optional<Float> = nil\n    public var presence_penalty: Optional<Float> = nil\n    public var logprobs: Bool = false\n    public var 
top_logprobs: Int = 0\n    public var logit_bias: Optional<[Int: Float]> = nil\n    public var max_tokens: Optional<Int> = nil\n    public var n: Int = 1\n    public var seed: Optional<Int> = nil\n    public var stop: Optional<[String]> = nil\n    public var stream: Bool = true\n    public var stream_options: Optional<StreamOptions> = nil\n    public var temperature: Optional<Float> = nil\n    public var top_p: Optional<Float> = nil\n    public var tools: Optional<[ChatTool]> = nil\n    public var user: Optional<String> = nil\n    public var response_format: Optional<ResponseFormat> = nil\n\n    public init(\n        messages: [ChatCompletionMessage],\n        model: Optional<String> = nil,\n        frequency_penalty: Optional<Float> = nil,\n        presence_penalty: Optional<Float> = nil,\n        logprobs: Bool = false,\n        top_logprobs: Int = 0,\n        logit_bias: Optional<[Int : Float]> = nil,\n        max_tokens: Optional<Int> = nil,\n        n: Int = 1,\n        seed: Optional<Int> = nil,\n        stop: Optional<[String]> = nil,\n        stream: Bool = true,\n        stream_options: Optional<StreamOptions> = nil,\n        temperature: Optional<Float> = nil,\n        top_p: Optional<Float> = nil,\n        tools: Optional<[ChatTool]> = nil,\n        user: Optional<String> = nil,\n        response_format: Optional<ResponseFormat> = nil\n    ) {\n        self.messages = messages\n        self.model = model\n        self.frequency_penalty = frequency_penalty\n        self.presence_penalty = presence_penalty\n        self.logprobs = logprobs\n        self.top_logprobs = top_logprobs\n        self.logit_bias = logit_bias\n        self.max_tokens = max_tokens\n        self.n = n\n        self.seed = seed\n        self.stop = stop\n        self.stream = stream\n        self.stream_options = stream_options\n        self.temperature = temperature\n        self.top_p = top_p\n        self.tools = tools\n        self.user = user\n        self.response_format = 
response_format\n    }\n}\n"
  },
  {
    "path": "ios/README.md",
    "content": "# MLC-LLM iOS\n\n[Documentation page](https://llm.mlc.ai/docs/deploy/ios.html)\n"
  },
  {
    "path": "ios/prepare_libs.sh",
    "content": "# Command to prepare the mlc llm static libraries\n# This command will be invoked by the \"mlc_llm package\" command\nfunction help {\n    echo -e \"OPTION:\"\n    echo -e \"  -s, --simulator                      Build for Simulator\"\n    echo -e \"  -a, --arch        x86_64 | arm64     Simulator arch \"\n    echo -e \"  -c, --catalyst                       Build for Mac Catalyst (arm64 only)\"\n    echo -e \"      --deployment-target VERSION      Mac Catalyst deployment target (default: 18.0)\"\n    echo -e \"  -h, --help                           Prints this help\\n\"\n}\n\nMLC_LLM_SOURCE_DIR=\"${MLC_LLM_SOURCE_DIR:-..}\"\nis_simulator=\"false\"\nis_catalyst=\"false\"\narch=\"arm64\"\ndeployment_target=\"18.0\"\n\n# rustup is required to install iOS target stdlibs, and we need to make sure\n# the rustup-managed cargo/rustc are used during cross compilation.\nif [ -d \"${HOME}/.cargo/bin\" ]; then\n  export PATH=\"${HOME}/.cargo/bin:${PATH}\"\nfi\nif ! command -v rustup >/dev/null 2>&1; then\n  echo \"error: rustup is required to build iOS static libraries.\" >&2\n  echo \"Install rustup and retry, e.g.:\" >&2\n  echo \"  curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y\" >&2\n  exit 1\nfi\n\n# Args while-loop\nwhile [ \"$1\" != \"\" ];\ndo\n   case $1 in\n   -s  | --simulator  )   is_simulator=\"true\"\n                          ;;\n   -c  | --catalyst  )    is_catalyst=\"true\"\n                          ;;\n   -a  | --arch  )        shift\n                          arch=$1\n                          ;;\n   --deployment-target )  shift\n                          deployment_target=$1\n                          ;;\n   -h   | --help )        help\n                          exit\n                          ;;\n   *)\n                          echo \"$0: illegal option $1\" >&2\n                          help\n                          exit 1 # error\n                          ;;\n    esac\n    
shift\ndone\n\nset -euxo pipefail\n\nsysroot=\"iphoneos\"\ntype=\"Release\"\nbuild_dir=\"build\"\n\nif [ \"$is_catalyst\" = \"true\" ]; then\n  if [ \"$is_simulator\" = \"true\" ]; then\n    echo \"error: --simulator is not supported with --catalyst.\" >&2\n    exit 1\n  fi\n  if [ \"$arch\" != \"x86_64\" ]; then\n    arch=\"arm64\"\n  fi\n  sysroot=\"macosx\"\n  build_dir=\"build-maccatalyst-$arch\"\nfi\n\nif [ \"$is_simulator\" = \"true\" ]; then\n  if [ \"$arch\" = \"arm64\" ]; then\n    # iOS simulator on Apple processors\n    rustup target add aarch64-apple-ios-sim\n  else\n    # iOS simulator on x86 processors\n    rustup target add x86_64-apple-ios\n  fi\n  sysroot=\"iphonesimulator\"\n  type=\"Debug\"\nelse\n  # iOS devices\n  rustup target add aarch64-apple-ios\n  if [ \"$is_catalyst\" = \"true\" ]; then\n    if [ \"$arch\" = \"x86_64\" ]; then\n      rustup target add x86_64-apple-ios-macabi\n    else\n      rustup target add aarch64-apple-ios-macabi\n    fi\n  fi\nfi\n\nmkdir -p \"$build_dir\" && cd \"$build_dir\"\n\ncmake_args=(\n  -DCMAKE_BUILD_TYPE=\"$type\"\n  -DCMAKE_BUILD_WITH_INSTALL_NAME_DIR=ON\n  -DCMAKE_SKIP_INSTALL_ALL_DEPENDENCY=ON\n  -DCMAKE_INSTALL_PREFIX=.\n  -DCMAKE_CXX_FLAGS=\"-O3\"\n  -DMLC_LLM_INSTALL_STATIC_LIB=ON\n  -DUSE_METAL=ON\n  -DTVM_FFI_USE_LIBBACKTRACE=OFF\n  -DTVM_FFI_BACKTRACE_ON_SEGFAULT=OFF\n)\n\nif [ \"$is_catalyst\" = \"true\" ]; then\n  toolchain=\"$MLC_LLM_SOURCE_DIR/3rdparty/tokenizers-cpp/sentencepiece/cmake/ios.toolchain.cmake\"\n  if [ \"$arch\" = \"x86_64\" ]; then\n    platform=\"MAC_CATALYST\"\n  else\n    platform=\"MAC_CATALYST_ARM64\"\n  fi\n  cmake_args+=(\n    -DCMAKE_TOOLCHAIN_FILE=\"$toolchain\"\n    -DPLATFORM=\"$platform\"\n    -DDEPLOYMENT_TARGET=\"$deployment_target\"\n    -DENABLE_BITCODE=OFF\n  )\nelse\n  cmake_args+=(\n    -DCMAKE_SYSTEM_NAME=iOS\n    -DCMAKE_SYSTEM_VERSION=14.0\n    -DCMAKE_OSX_SYSROOT=\"$sysroot\"\n    -DCMAKE_OSX_ARCHITECTURES=\"$arch\"\n    -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0\n 
 )\nfi\n\ncmake \"$MLC_LLM_SOURCE_DIR\" \"${cmake_args[@]}\"\n\n\ncmake --build . --config release --target mlc_llm_static -j\ncmake --build . --target install --config release -j\ncd ..\n\nrm -rf $MLC_LLM_SOURCE_DIR/ios/MLCSwift/tvm_home\nln -s $MLC_LLM_SOURCE_DIR/3rdparty/tvm $MLC_LLM_SOURCE_DIR/ios/MLCSwift/tvm_home\n"
  },
  {
    "path": "pyproject.toml",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\n[project]\nname = \"mlc_llm\"\n# Note: Call version.py to update the version before building the wheel\nversion = \"0.20.0.dev0\"\ndescription = \"MLC LLM: a universal LLM deployment engine via ML compilation.\"\n\nauthors = [{ name = \"MLC LLM Contributors\" }]\nreadme = \"README.md\"\nlicense = { text = \"Apache 2.0\" }\nclassifiers = [\n  \"License :: OSI Approved :: Apache Software License\",\n  \"Development Status :: 4 - Beta\",\n  \"Intended Audience :: Developers\",\n  \"Intended Audience :: Education\",\n  \"Intended Audience :: Science/Research\",\n]\nkeywords = [\"machine learning\"]\nrequires-python = \">=3.9\"\n\ndependencies = [\n    \"apache-tvm-ffi\",\n    \"datasets\",\n    \"fastapi\",\n    \"flashinfer-python; sys_platform == 'linux'\",\n    \"ml_dtypes>=0.5.1\",\n    \"openai\",\n    \"pandas\",\n    \"prompt_toolkit\",\n    \"requests\",\n    \"safetensors\",\n    \"sentencepiece\",\n    \"shortuuid\",\n    \"tiktoken\",\n    \"torch\",\n    \"tqdm\",\n    \"transformers\",\n    \"uvicorn\",\n]\n\n[project.urls]\nHomepage = \"https://llm.mlc.ai/\"\nDocumentation = \"https://llm.mlc.ai/docs/\"\nRepository = 
\"https://github.com/mlc-ai/mlc-llm\"\n\"Bug Tracker\" = \"https://github.com/mlc-ai/mlc-llm/issues\"\n\n\n\n[build-system]\nrequires = [\"scikit-build-core>=0.10.0\"]\nbuild-backend = \"scikit_build_core.build\"\n\n[tool.scikit-build]\n# Point to the root CMakeLists.txt\ncmake.source-dir = \".\"\ncmake.build-type = \"Release\"\n\n# Configure the wheel to be Python version-agnostic\nwheel.py-api = \"py3\"\n\n# Build configuration\nbuild-dir = \"build\"\n\n# CMake configuration - ensure proper installation paths\ncmake.args = [\"-DMLC_LLM_BUILD_PYTHON_MODULE=ON\"]\n\n# Wheel configuration\nwheel.packages = [\"python/mlc_llm\"]\nwheel.install-dir = \"mlc_llm\"\n\n# Source distribution configuration\nsdist.include = [\n    # Build files\n    \"/CMakeLists.txt\",\n    \"/pyproject.toml\",\n    \"/cmake/**/*\",\n    \"/3rdparty/**/*\",\n\n    # Source code\n    \"/src/**/*.cc\",\n    \"/src/**/*.h\",\n\n    # Python source\n    \"/python/mlc_llm/**/*.py\",\n\n    # Documentation and metadata\n    \"/docs/**/*\",\n    \"/LICENSE\",\n    \"/README.md\",\n    \"/NOTICE\",\n\n    # Tests\n    \"/tests/**/*\",\n\n    \"/.pre-commit-config.yaml\",\n    \"/.pylintrc\",\n]\n\nsdist.exclude = [\n    \"**/.git\",\n    \"**/.github\",\n    \"**/__pycache__\",\n    \"**/*.pyc\",\n    \"build\",\n    \"dist\",\n    \"3rdparty/tvm\",\n    \"**/3rdparty/*/docs\",\n    \"**/3rdparty/*/media\",\n    \"**/3rdparty/*/examples\",\n    \"**/3rdparty/*/test\",\n]\n\n# Logging\nlogging.level = \"INFO\"\n\n\n[tool.isort]\nprofile = \"black\"\nsrc_paths = [\"python/mlc_llm\"]\nknown_third_party = [\"numpy\", \"tvm\", \"tqdm\", \"torch\", \"transformers\"]\n\n[tool.black]\nline-length = 100\n\n[tool.mypy]\nignore_missing_imports = true\nshow_column_numbers = true\nshow_error_context = true\nfollow_imports = \"skip\"\nignore_errors = false\nstrict_optional = false\n\n[tool.pylint.messages_control]\nmax-line-length = 100\ndisable = \"\"\"\nduplicate-code,\n\"\"\"\n"
  },
  {
    "path": "python/mlc_llm/__init__.py",
    "content": "\"\"\"MLC Chat python package.\n\nMLC Chat is the app runtime of MLC LLM.\n\"\"\"\n\nfrom tvm import register_global_func\n\nfrom . import protocol, serve\nfrom .libinfo import __version__\nfrom .serve import AsyncMLCEngine, MLCEngine\n\n\n@register_global_func(\"runtime.disco.create_socket_session_local_workers\", override=True)\ndef _create_socket_session_local_workers(num_workers):\n    \"\"\"Create the local session for each distributed node over socket session.\"\"\"\n    from tvm.runtime.disco import (  # pylint: disable=import-outside-toplevel\n        ProcessSession,\n    )\n\n    return ProcessSession(num_workers, num_groups=1, entrypoint=\"mlc_llm.cli.worker\")\n"
  },
  {
    "path": "python/mlc_llm/__main__.py",
    "content": "\"\"\"Entrypoint of all CLI commands from MLC LLM\"\"\"\n\nimport sys\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.argparse import ArgumentParser\n\nlogging.enable_logging()\n\n\ndef main():\n    \"\"\"Entrypoint of all CLI commands from MLC LLM\"\"\"\n    parser = ArgumentParser(\"MLC LLM Command Line Interface.\")\n    parser.add_argument(\n        \"subcommand\",\n        type=str,\n        choices=[\n            \"compile\",\n            \"convert_weight\",\n            \"gen_config\",\n            \"chat\",\n            \"serve\",\n            \"package\",\n            \"calibrate\",\n            \"router\",\n        ],\n        help=\"Subcommand to run. (choices: %(choices)s)\",\n    )\n    parsed = parser.parse_args(sys.argv[1:2])\n    # pylint: disable=import-outside-toplevel\n    if parsed.subcommand == \"compile\":\n        from mlc_llm.cli import compile as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"convert_weight\":\n        from mlc_llm.cli import convert_weight as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"gen_config\":\n        from mlc_llm.cli import gen_config as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"chat\":\n        from mlc_llm.cli import chat as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"serve\":\n        from mlc_llm.cli import serve as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"package\":\n        from mlc_llm.cli import package as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"calibrate\":\n        from mlc_llm.cli import calibrate as cli\n\n        cli.main(sys.argv[2:])\n    elif parsed.subcommand == \"router\":\n        from mlc_llm.cli import router as cli\n\n        cli.main(sys.argv[2:])\n    else:\n        raise ValueError(f\"Unknown subcommand {parsed.subcommand}\")\n    # pylint: enable=import-outside-toplevel\n\n\nif __name__ == 
\"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/base.py",
    "content": "\"\"\"Load MLC LLM library and _ffi_api functions.\"\"\"\n\nimport ctypes\nimport os\nimport sys\n\nimport tvm\nimport tvm.base\n\nfrom . import libinfo\n\nSKIP_LOADING_MLCLLM_SO = os.environ.get(\"SKIP_LOADING_MLCLLM_SO\", \"0\")\n\n\ndef _load_mlc_llm_lib():\n    \"\"\"Load MLC LLM lib\"\"\"\n    if sys.platform.startswith(\"win32\") and sys.version_info >= (3, 8):\n        for path in libinfo.get_dll_directories():\n            os.add_dll_directory(path)\n    # pylint: disable=protected-access\n    lib_name = \"mlc_llm\" if tvm.base._RUNTIME_ONLY else \"mlc_llm_module\"\n    # pylint: enable=protected-access\n    lib_path = libinfo.find_lib_path(lib_name, optional=False)\n    return ctypes.CDLL(lib_path[0]), lib_path[0]\n\n\n@tvm.register_global_func(\"mlc.debug_cuda_profiler_start\")\ndef _debug_cuda_profiler_start() -> None:\n    \"\"\"Start cuda profiler.\"\"\"\n    import cuda  # pylint: disable=import-outside-toplevel\n    import cuda.cudart  # pylint: disable=import-outside-toplevel,import-error,no-name-in-module\n\n    cuda.cudart.cudaProfilerStart()  # pylint: disable=c-extension-no-member\n\n\n@tvm.register_global_func(\"mlc.debug_cuda_profiler_stop\")\ndef _debug_cuda_profiler_stop() -> None:\n    \"\"\"Stop cuda profiler.\"\"\"\n    import cuda  # pylint: disable=import-outside-toplevel\n    import cuda.cudart  # pylint: disable=import-outside-toplevel,import-error,no-name-in-module\n\n    cuda.cudart.cudaProfilerStop()  # pylint: disable=c-extension-no-member\n\n\n# only load once here\nif SKIP_LOADING_MLCLLM_SO == \"0\":\n    _LIB, _LIB_PATH = _load_mlc_llm_lib()\n"
  },
  {
    "path": "python/mlc_llm/bench/__init__.py",
    "content": "\"\"\"Subdirectory of bench.\"\"\"\n"
  },
  {
    "path": "python/mlc_llm/bench/__main__.py",
    "content": "\"\"\"MLC LLM benchmark main entrance\"\"\"\n\nimport functools\nimport json\nimport random\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport numpy as np\nimport requests\nfrom transformers import AutoTokenizer  # pylint: disable=import-error\n\nimport mlc_llm\nfrom mlc_llm.bench.api_endpoint import SUPPORTED_BACKENDS, create_api_endpoint\nfrom mlc_llm.bench.dataset import SUPPORTED_DATASET, Dataset, create_dataset\nfrom mlc_llm.bench.request_processor import (\n    MetricAnalyzer,\n    RequestProcessor,\n    create_pipelines,\n)\nfrom mlc_llm.bench.request_record import (\n    RequestRecord,\n    convert_reports_to_df,\n    generate_metrics_summary,\n    pretty_print_report,\n)\nfrom mlc_llm.cli.serve import EngineConfigOverride\nfrom mlc_llm.serve import EngineConfig\nfrom mlc_llm.support import argparse, logging\n\nlogging.enable_logging()\nlogger = logging.getLogger(__name__)\n\n\ndef _parse_num_concurrent_requests(num_str: Optional[str]) -> Optional[List[int]]:\n    if num_str is None:\n        return None\n    numbers = num_str.split(\",\")\n    if any(not number.isdigit() for number in numbers):\n        raise ValueError(f\"Unrecognized num_concurrent_requests list: {numbers}\")\n    return list(int(number) for number in numbers)\n\n\ndef _parse_request_rate(request_rate_str: Optional[str]) -> Optional[List[np.float32]]:\n    if request_rate_str is None:\n        return None\n    request_rates = request_rate_str.split(\",\")\n    results = []\n    for rate_str in request_rates:\n        request_rate = float(rate_str)\n        if request_rate <= 0:\n            raise ValueError(f\"Invalid request rate {request_rate}\")\n        results.append(np.float32(request_rate))\n    return results\n\n\ndef _parse_mlc_engine_config(config_str: Optional[str]) -> EngineConfig:\n    if config_str is None:\n        return None\n    engine_config_override = EngineConfigOverride.from_str(config_str)\n    return EngineConfig(\n        
tensor_parallel_shards=engine_config_override.tensor_parallel_shards,\n        max_num_sequence=engine_config_override.max_num_sequence,\n        max_total_sequence_length=engine_config_override.max_total_seq_length,\n        prefill_chunk_size=engine_config_override.prefill_chunk_size,\n        sliding_window_size=engine_config_override.sliding_window_size,\n        attention_sink_size=engine_config_override.attention_sink_size,\n        max_history_size=engine_config_override.max_history_size,\n        gpu_memory_utilization=engine_config_override.gpu_memory_utilization,\n        spec_draft_length=engine_config_override.spec_draft_length,\n        prefill_mode=engine_config_override.prefill_mode,\n        prefix_cache_max_num_recycling_seqs=engine_config_override.prefix_cache_max_num_recycling_seqs,  # pylint: disable=line-too-long\n        prefix_cache_mode=engine_config_override.prefix_cache_mode,\n    )\n\n\ndef _launch_mlc_server(args: argparse.argparse.Namespace):\n    return mlc_llm.serve.PopenServer(\n        model=args.tokenizer,\n        mode=\"server\",\n        model_lib=args.mlc_model_lib,\n        enable_tracing=False,\n        host=args.host,\n        port=args.port,\n        engine_config=args.mlc_engine_config,\n    )\n\n\ndef run_pipeline(\n    pipeline: RequestProcessor,\n    dataset: Dataset,\n    tokenizer: AutoTokenizer,\n    args: argparse.argparse.Namespace,\n) -> Tuple[Dict[str, Any], List[RequestRecord]]:\n    \"\"\"Run the pipeline with the given dataset and args. 
Return the benchmark report dict.\"\"\"\n    random.seed(args.seed)\n    np.random.seed(args.seed)\n    request_records = dataset.generate_request_records(\n        args.input_len,\n        args.output_len,\n        args.input_len_std,\n        args.output_len_std,\n    )\n    request_records = pipeline(request_records)\n    num_total_requests = (\n        args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus\n    )\n    assert len(request_records) == num_total_requests\n    sorted_requests: List[RequestRecord] = [None] * num_total_requests\n    for request_record in request_records:\n        assert request_record.request_id is not None\n        assert sorted_requests[request_record.request_id] is None\n        sorted_requests[request_record.request_id] = request_record\n\n    request_records = MetricAnalyzer(tokenizer)(request_records)\n    report = generate_metrics_summary(request_records, num_total_requests, args.num_gpus)\n    return report, sorted_requests\n\n\ndef query_mlc_server_metrics(host: str, port: int):\n    \"\"\"Try to get the MLC server metrics whenever it exists.\"\"\"\n    try:\n        r = requests.post(f\"http://{host}:{port}/debug/dump_engine_metrics\", json={}, timeout=10)\n        if r.status_code == 200:\n            print(f\"MLC server metrics: {r.json()}\")\n    except Exception:  # pylint: disable=broad-exception-caught\n        pass\n\n\ndef main(args: argparse.argparse.Namespace):\n    \"\"\"Main benchmark entrance.\"\"\"\n    mlc_server = None\n    if args.mlc_model_lib:\n        mlc_server = _launch_mlc_server(args)\n    if args.num_requests <= 0:\n        raise ValueError(\"Number of requests to benchmark must be positive.\")\n\n    def _main():\n        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)\n        dataset = create_dataset(args, tokenizer)\n        f_create_api_endpoint = functools.partial(create_api_endpoint, args)\n        pipelines = create_pipelines(args, 
f_create_api_endpoint, dataset)\n        reports = []\n        alltime_records = {}\n        for i, pipeline in enumerate(pipelines):\n            report, request_records = run_pipeline(pipeline, dataset, tokenizer, args)\n            exec_feature = (\n                json.dumps(report[\"exec_feature\"])\n                if report[\"exec_feature\"] is not None\n                else f\"pipeline{i}\"\n            )\n            alltime_records[exec_feature] = [\n                request_record.model_dump() for request_record in request_records\n            ]\n            reports.append(report)\n            pretty_print_report(report)\n        query_mlc_server_metrics(args.host, args.port)\n\n        # Construct data frame\n        df = convert_reports_to_df(reports)\n        print(df)\n        df.to_csv(args.output, index=False)\n        logger.info(\"Benchmark results dumped to file %s\", args.output)\n        if args.debug_dump:\n            debug_dump_filepath = (\n                args.output[:-4] if args.output.endswith(\".csv\") else args.output\n            ) + \"_debug_dump.log\"\n            with open(debug_dump_filepath, \"w\", encoding=\"utf-8\") as file:\n                json.dump(alltime_records, file, indent=4)\n            logger.info(\"Debug log dumped to file %s\", debug_dump_filepath)\n\n    if mlc_server is not None:\n        with mlc_server:\n            _main()\n    else:\n        _main()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"MLC LLM benchmark\")\n\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        choices=SUPPORTED_DATASET,\n        help=f\"The benchmark dataset kind. 
Supported kinds: {SUPPORTED_DATASET}\",\n    )\n    parser.add_argument(\n        \"--dataset-path\",\n        type=str,\n        help=\"The dataset file path.\",\n    )\n    parser.add_argument(\n        \"--api-endpoint\",\n        type=str,\n        choices=SUPPORTED_BACKENDS,\n        default=\"openai\",\n        help=\"The kind of API endpoint to benchmark.\",\n    )\n    parser.add_argument(\n        \"--tokenizer\",\n        type=str,\n        required=True,\n        help=\"The path of the tokenizer directory.\",\n    )\n    parser.add_argument(\n        \"--num-gpus\",\n        type=int,\n        required=True,\n        help=\"The number of GPUs used by the server. \"\n        \"We need this to better analyze the throughput per GPU.\",\n    )\n    parser.add_argument(\n        \"--num-requests\",\n        type=int,\n        required=True,\n        help=\"The number of requests to benchmark.\",\n    )\n    parser.add_argument(\n        \"--num-warmup-requests\",\n        type=int,\n        help=\"The number of requests for warmup. \"\n        \"It is optional when fixing the number of concurrent requests, and is required otherwise.\",\n    )\n    parser.add_argument(\n        \"--per-gpu-workload\",\n        default=False,\n        action=\"store_true\",\n        help='When set, the specified \"num_concurrent_requests\"/\"request_rate\" '\n        \"denote the workload **per GPU**, which means that the real values of \"\n        '\"num_concurrent_requests\"/\"request_rate\" used in the benchmark '\n        'will be multiplied by \"num_gpus\".',\n    )\n    parser.add_argument(\n        \"--num-concurrent-requests\",\n        type=_parse_num_concurrent_requests,\n        help=\"The number(s) of concurrent requests to benchmark. \"\n        'It can be either one integer or a list of integers separated by commas (\",\"). '\n        \"When specified, for each integer, the benchmark keeps that many requests \"\n        \"running concurrently.\",
\n    )\n    parser.add_argument(\n        \"--request-rate\",\n        type=_parse_request_rate,\n        help=\"The request rate(s), i.e., the number of new requests sent per second. \"\n        'It can be either one float (or \"inf\") or a list of numbers separated '\n        'by commas (\",\"). '\n        \"When specified, the benchmark sends that many new requests each second. \"\n        'If it is \"inf\", all requests are sent at once.',\n    )\n    parser.add_argument(\n        \"--replay-timestamp-scale\",\n        type=float,\n        help=\"The timestamp scale when replaying the timestamps in a dataset. \"\n        'The dataset replay mode is enabled when neither \"--num-concurrent-requests\" nor '\n        '\"--request-rate\" is specified. '\n        \"The scale is 1 by default in the replay mode.\",\n    )\n    parser.add_argument(\n        \"--input-len\",\n        type=int,\n        help=\"The benchmark request average input length. Defaults to None, \"\n        \"which means the request input length depends on the dataset being used.\",\n    )\n    parser.add_argument(\n        \"--input-len-std\",\n        type=float,\n        default=0,\n        help=\"The benchmark request input length standard deviation. Defaults to 0.\",\n    )\n    parser.add_argument(\n        \"--output-len\",\n        type=int,\n        help=\"The benchmark request average output length. Defaults to None, \"\n        \"which means the request output length depends on the dataset being used.\",\n    )\n    parser.add_argument(\n        \"--output-len-std\",\n        type=float,\n        default=0,\n        help=\"The benchmark request output length standard deviation. Defaults to 0.\",
\n    )\n    parser.add_argument(\n        \"--stream\",\n        type=bool,\n        default=True,\n        help=\"Whether to benchmark streaming responses. \"\n        \"When not enabled, metrics such as time-to-first-token (TTFT) will not be available. \"\n        \"Defaults to True.\",\n    )\n    parser.add_argument(\n        # NOTE: The current server metrics implementation still has unresolved issues,\n        # so including server metrics does not work yet.\n        \"--include-server-metrics\",\n        action=\"store_true\",\n        help=\"Whether to also benchmark the server-side request metrics. \"\n        \"This option is only available when benchmarking the MLC server.\",\n    )\n    parser.add_argument(\n        \"--host\",\n        type=str,\n        required=True,\n        help=\"The host address of the backend API.\",\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        required=True,\n        help=\"The port of the backend API.\",\n    )\n    parser.add_argument(\n        \"--timeout\",\n        type=float,\n        default=3 * 60 * 60,\n        help=\"The timeout limit of each request, in seconds.\",\n    )\n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=0,\n        help=\"The random number seed. Defaults to 0.\",\n    )\n    parser.add_argument(\n        \"--temperature\",\n        type=float,\n        default=1.0,\n        help=\"The temperature value for logit adjustment. Defaults to 1.\",\n    )\n    parser.add_argument(\n        \"--top-p\",\n        type=float,\n        default=1.0,\n        help=\"The top-p value for sampling. Defaults to 1.\",
\n    )\n    parser.add_argument(\n        \"--ignore-eos\",\n        default=False,\n        action=\"store_true\",\n        help='Whether to set the \"ignore_eos\" field.',\n    )\n    parser.add_argument(\n        \"--apply-chat-template\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to apply the chat template to the request input text. \"\n        'It is not supported when \"--input-len\" is specified.',\n    )\n    parser.add_argument(\n        \"--num-process-workers\",\n        type=int,\n        help=\"The number of parallel process workers to send the requests.\",\n    )\n    parser.add_argument(\n        \"--disable-tqdm\",\n        action=\"store_true\",\n        help=\"Whether to disable the tqdm progress bar during benchmarking.\",\n    )\n    parser.add_argument(\n        \"--max-schedule-gap\",\n        type=float,\n        default=0.5,\n        help=\"The maximum allowed delay, in seconds, relative to the scheduled time.\",\n    )\n    parser.add_argument(\n        \"--mlc-model-lib\",\n        type=str,\n        help=\"The model lib path when benchmarking MLC serve. \"\n        \"When specified, the server is launched automatically and no external server is needed.\",\n    )\n    parser.add_argument(\n        \"--mlc-engine-config\",\n        type=_parse_mlc_engine_config,\n        help=\"The engine config used when launching the MLC server.\",\n    )\n    parser.add_argument(\n        \"--cuda-profile\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to enable CUDA profiling on the server. \"\n        \"The --mlc-model-lib path should be provided when enabling this option.\",
\n    )\n    parser.add_argument(\n        \"--debug-dump\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to dump the raw data of all request records to a file.\",\n    )\n    parser.add_argument(\n        \"--multi-round\",\n        default=False,\n        action=\"store_true\",\n        help=\"Whether to simulate multi-round conversations, where each request \"\n        \"carries the conversation history. \"\n        \"Only enabled when benchmarking in the fixed-concurrent-request mode. \"\n        \"--num-concurrent-requests should be provided when enabling this option.\",\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=str,\n        default=\"mlc_benchmark.csv\",\n        help=\"The path of the output file to which the benchmark results are dumped.\",\n    )\n\n    main(parser.parse_args())\n"
  },
  {
    "path": "python/mlc_llm/bench/api_endpoint.py",
    "content": "\"\"\"MLC LLM bench backends\"\"\"\n\nimport argparse\nimport json\nimport os\nimport time\nimport traceback\nfrom typing import Optional\n\nfrom typing_extensions import Self\n\nfrom mlc_llm.bench.request_record import Metrics, RequestRecord, ServerMetrics\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\nclass APIEndPoint:\n    \"\"\"Manages the sending of requests to a specified API endpoint and gathers\n    inference statistics.\n    \"\"\"\n\n    def __init__(self, include_server_metrics: bool = False) -> None:\n        self.include_server_metrics = include_server_metrics\n\n    async def __aenter__(self) -> Self:\n        return self\n\n    async def __aexit__(self, exc_type, exc_value, tb) -> None:\n        pass\n\n    async def __call__(self, request: RequestRecord) -> RequestRecord:\n        raise NotImplementedError()\n\n\nclass OpenAIChatEndPoint(APIEndPoint):\n    \"\"\"The backend of sending HTTP requests in OpenAI API through \"v1/chat/completions\".\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        host: str,\n        port: int,\n        timeout: Optional[float] = None,\n        include_server_metrics: bool = False,\n    ) -> None:\n        super().__init__(include_server_metrics=include_server_metrics)\n\n        import aiohttp  # pylint: disable=import-outside-toplevel,import-error\n\n        self.timeout = timeout\n        self.client: aiohttp.ClientSession = None\n        self.url = f\"http://{host}:{port}/v1/chat/completions\"\n        self.headers = {\"Content-Type\": \"application/json\"}\n        if os.getenv(\"MLC_LLM_API_KEY\"):\n            self.headers[\"Authorization\"] = f\"Bearer {os.getenv('MLC_LLM_API_KEY')}\"\n\n    async def __aenter__(self) -> Self:\n        import aiohttp  # pylint: disable=import-outside-toplevel,import-error\n\n        self.client = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(self.timeout))\n        return self\n\n    
async def __aexit__(self, exc_type, exc_value, tb) -> None:\n        await self.client.close()\n\n    async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals\n        self, request_record: RequestRecord\n    ) -> RequestRecord:\n        payload = request_record.chat_cmpl.model_dump()\n        if self.timeout is not None and \"timeout\" not in payload:\n            payload[\"timeout\"] = self.timeout\n        if self.include_server_metrics:\n            if \"stream_options\" not in payload or payload[\"stream_options\"] is None:\n                payload[\"stream_options\"] = {\"include_usage\": True}\n            else:\n                payload[\"stream_options\"][\"include_usage\"] = True\n        if (\n            request_record.chat_cmpl.debug_config is not None\n            and request_record.chat_cmpl.debug_config.ignore_eos\n        ):\n            payload[\"ignore_eos\"] = True\n\n        generated_text = \"\"\n        first_chunk_output_str = \"\"\n        time_to_first_token_s = None\n        start_time = time.monotonic()\n        server_metrics = None\n\n        try:\n            async with self.client.post(self.url, json=payload, headers=self.headers) as response:\n                assert response.status == 200, await response.text()\n                if payload[\"stream\"]:\n                    async for chunk in response.content:\n                        chunk = chunk.strip()\n                        if not chunk or chunk == b\"\\n\":\n                            continue\n                        # Get rid of the prefix \"data: \" and suffix \"\\n\"\n                        raw_data = chunk[6:].strip()\n                        if raw_data == b\"[DONE]\":\n                            continue\n                        data = json.loads(raw_data)\n                        if not data[\"choices\"]:\n                            continue\n                        delta = data[\"choices\"][0][\"delta\"]\n                        
content = delta.get(\"content\", None)\n                        if content is not None and not time_to_first_token_s:\n                            time_to_first_token_s = time.monotonic() - start_time\n                            first_chunk_output_str = content\n                        if self.include_server_metrics and data[\"usage\"] is not None:\n                            # fmt: off\n                            # pylint: disable=line-too-long\n                            server_metrics = ServerMetrics(\n                                input_tokens=data[\"usage\"][\"extra\"][\"prompt_tokens\"],\n                                prefill_tokens=data[\"usage\"][\"extra\"][\"prefill_tokens\"],\n                                output_tokens=data[\"usage\"][\"extra\"][\"completion_tokens\"],\n                                end_to_end_latency_s=data[\"usage\"][\"extra\"][\"end_to_end_latency_s\"],\n                                prefill_tokens_per_s=data[\"usage\"][\"extra\"][\"prefill_tokens_per_s\"],\n                                inter_token_latency_s=data[\"usage\"][\"extra\"][\"inter_token_latency_s\"],\n                                time_per_output_token_s=1 / data[\"usage\"][\"extra\"][\"decode_tokens_per_s\"],\n                                time_to_first_token_s=data[\"usage\"][\"extra\"][\"ttft_s\"],\n                            )\n                            # pylint: enable=line-too-long\n                            # fmt: on\n\n                        if content is not None:\n                            generated_text += content\n                else:\n                    data = await response.json()\n                    generated_text = data[\"choices\"][0][\"message\"][\"content\"]\n                    if self.include_server_metrics and data[\"usage\"] is not None:\n                        # fmt: off\n                        # pylint: disable=line-too-long\n                        server_metrics = ServerMetrics(\n                            
input_tokens=data[\"usage\"][\"extra\"][\"prompt_tokens\"],\n                            prefill_tokens=data[\"usage\"][\"extra\"][\"prefill_tokens\"],\n                            output_tokens=data[\"usage\"][\"extra\"][\"completion_tokens\"],\n                            end_to_end_latency_s=data[\"usage\"][\"extra\"][\"end_to_end_latency_s\"],\n                            prefill_tokens_per_s=data[\"usage\"][\"extra\"][\"prefill_tokens_per_s\"],\n                            inter_token_latency_s=data[\"usage\"][\"extra\"][\"inter_token_latency_s\"],\n                            time_per_output_token_s=1 / data[\"usage\"][\"extra\"][\"decode_tokens_per_s\"],\n                            time_to_first_token_s=data[\"usage\"][\"extra\"][\"ttft_s\"],\n                        )\n                        # pylint: enable=line-too-long\n                        # fmt: on\n        except Exception:  # pylint: disable=broad-except\n            error_msg = \"API endpoint errored when sending request: \" + traceback.format_exc()\n            logger.info(error_msg)\n            finish_time = time.monotonic()\n            request_record.output_str = generated_text\n            request_record.first_chunk_output_str = first_chunk_output_str\n            request_record.metrics = Metrics(\n                success=False,\n                start_time=start_time,\n                finish_time=finish_time,\n                end_to_end_latency_s=finish_time - start_time,\n                input_tokens=request_record.metrics.input_tokens,\n                time_to_first_token_s=time_to_first_token_s,\n                server_metrics=server_metrics,\n                exec_feature=request_record.metrics.exec_feature,\n            )\n            request_record.error_msg = error_msg\n            return request_record\n\n        finish_time = time.monotonic()\n        request_record.output_str = generated_text\n        request_record.first_chunk_output_str = first_chunk_output_str\n        success 
= True\n        error_msg = None\n        if len(generated_text) == 0:\n            success = False\n            error_msg = \"Empty generated text.\"\n        request_record.metrics = Metrics(\n            success=success,\n            start_time=start_time,\n            finish_time=finish_time,\n            end_to_end_latency_s=finish_time - start_time,\n            input_tokens=request_record.metrics.input_tokens,\n            time_to_first_token_s=time_to_first_token_s,\n            server_metrics=server_metrics,\n            exec_feature=request_record.metrics.exec_feature,\n        )\n        request_record.error_msg = error_msg\n        return request_record\n\n\nclass OpenAIEndPoint(APIEndPoint):\n    \"\"\"The backend of sending HTTP requests in OpenAI API through \"v1/completions\".\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        host: str,\n        port: int,\n        timeout: Optional[float] = None,\n        include_server_metrics: bool = False,\n        no_debug_config: bool = False,\n    ) -> None:\n        super().__init__(include_server_metrics=include_server_metrics)\n\n        import aiohttp  # pylint: disable=import-outside-toplevel,import-error\n\n        self.timeout = timeout\n        self.client: aiohttp.ClientSession = None\n        self.url = f\"http://{host}:{port}/v1/completions\"\n        self.headers = {\"Content-Type\": \"application/json\"}\n        if os.getenv(\"MLC_LLM_API_KEY\"):\n            self.headers[\"Authorization\"] = f\"Bearer {os.getenv('MLC_LLM_API_KEY')}\"\n        assert (\n            not include_server_metrics\n        ), '\"include_server_metrics\" only works for \"openai-chat\" endpoint for now'\n        self.no_debug_config = no_debug_config\n\n    async def __aenter__(self) -> Self:\n        import aiohttp  # pylint: disable=import-outside-toplevel,import-error\n\n        self.client = aiohttp.ClientSession()\n        return self\n\n    async def __aexit__(self, exc_type, 
exc_value, tb) -> None:\n        await self.client.close()\n\n    async def __call__(  # pylint: disable=too-many-branches,too-many-statements\n        self, request_record: RequestRecord\n    ) -> RequestRecord:\n        assert (\n            len(request_record.chat_cmpl.messages) == 1\n        ), 'Endpoint \"openai\" does not support system prompt and multi-round conversation.'\n        assert isinstance(request_record.chat_cmpl.messages[0].content, str)\n        payload = {\n            \"model\": request_record.chat_cmpl.model,\n            \"prompt\": request_record.chat_cmpl.messages[0].content,\n            \"temperature\": request_record.chat_cmpl.temperature,\n            \"top_p\": request_record.chat_cmpl.top_p,\n            \"max_tokens\": request_record.chat_cmpl.max_tokens,\n            \"stream\": True,\n        }\n        if self.timeout is not None and \"timeout\" not in payload:\n            payload[\"timeout\"] = self.timeout\n        if (\n            request_record.chat_cmpl.debug_config is not None\n            and request_record.chat_cmpl.debug_config.ignore_eos\n        ):\n            payload[\"ignore_eos\"] = True\n            if not self.no_debug_config:\n                payload[\"debug_config\"] = {\"ignore_eos\": True}\n\n        generated_text = \"\"\n        first_chunk_output_str = \"\"\n        time_to_first_token_s = None\n        start_time = time.monotonic()\n\n        try:\n            async with self.client.post(\n                self.url, json=payload, headers=self.headers, timeout=3600\n            ) as response:\n                assert response.status == 200, await response.text()\n                if payload[\"stream\"]:\n                    async for chunk in response.content:\n                        chunk = chunk.strip()\n                        if not chunk or chunk == b\"\\n\":\n                            continue\n                        # Get rid of the prefix \"data: \" and suffix \"\\n\"\n                        
raw_data = chunk[6:].strip()\n                        if raw_data == b\"[DONE]\":\n                            continue\n                        data = json.loads(raw_data)\n                        if not data[\"choices\"]:\n                            continue\n                        content = data[\"choices\"][0][\"text\"]\n                        if content is not None and not time_to_first_token_s:\n                            time_to_first_token_s = time.monotonic() - start_time\n                            first_chunk_output_str = content\n                        if content is not None:\n                            generated_text += content\n                else:\n                    data = await response.json()\n                    generated_text = data[\"choices\"][0][\"message\"][\"content\"]\n        except Exception:  # pylint: disable=broad-except\n            error_msg = \"API endpoint errored when sending request: \" + traceback.format_exc()\n            logger.info(error_msg)\n            finish_time = time.monotonic()\n            request_record.output_str = generated_text\n            request_record.first_chunk_output_str = first_chunk_output_str\n            request_record.metrics = Metrics(\n                success=False,\n                start_time=start_time,\n                finish_time=finish_time,\n                end_to_end_latency_s=finish_time - start_time,\n                input_tokens=request_record.metrics.input_tokens,\n                time_to_first_token_s=time_to_first_token_s,\n                server_metrics=None,\n                exec_feature=request_record.metrics.exec_feature,\n            )\n            request_record.error_msg = error_msg\n            return request_record\n\n        finish_time = time.monotonic()\n        request_record.output_str = generated_text\n        request_record.first_chunk_output_str = first_chunk_output_str\n        success = True\n        error_msg = None\n        if len(generated_text) == 0:\n   
         success = False\n            error_msg = \"Empty generated text.\"\n        request_record.metrics = Metrics(\n            success=success,\n            start_time=start_time,\n            finish_time=finish_time,\n            end_to_end_latency_s=finish_time - start_time,\n            input_tokens=request_record.metrics.input_tokens,\n            time_to_first_token_s=time_to_first_token_s,\n            server_metrics=None,\n            exec_feature=request_record.metrics.exec_feature,\n        )\n        request_record.error_msg = error_msg\n        return request_record\n\n\nclass TensorRTLLMEndPoint(APIEndPoint):\n    \"\"\"The backend of sending HTTP requests in TensorRT-LLM API.\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self, host: str, port: int, timeout: Optional[float] = None\n    ) -> None:\n        super().__init__(include_server_metrics=False)\n\n        import aiohttp  # pylint: disable=import-outside-toplevel,import-error\n\n        self.timeout = timeout\n        self.client: aiohttp.ClientSession = None\n        self.url_stream = f\"http://{host}:{port}/v2/models/ensemble/generate_stream\"\n        self.url_no_stream = f\"http://{host}:{port}/v2/models/ensemble/generate\"\n\n    async def __aenter__(self) -> Self:\n        import aiohttp  # pylint: disable=import-outside-toplevel,import-error\n\n        self.client = aiohttp.ClientSession()\n        return self\n\n    async def __aexit__(self, exc_type, exc_value, tb) -> None:\n        await self.client.close()\n\n    async def __call__(  # pylint: disable=too-many-branches,too-many-locals,too-many-statements\n        self, request_record: RequestRecord\n    ) -> RequestRecord:\n        assert len(request_record.chat_cmpl.messages) == 1\n        assert isinstance(request_record.chat_cmpl.messages[0].content, str)\n        payload = {\n            \"accumulate_tokens\": True,\n            \"text_input\": request_record.chat_cmpl.messages[0].content,\n          
  \"temperature\": (\n                max(request_record.chat_cmpl.temperature, 1e-5)\n                if request_record.chat_cmpl.temperature\n                else 1\n            ),\n            \"top_p\": request_record.chat_cmpl.top_p if request_record.chat_cmpl.top_p else 1,\n            \"max_tokens\": request_record.chat_cmpl.max_tokens,\n            \"stream\": request_record.chat_cmpl.stream,\n        }\n        if (\n            request_record.chat_cmpl.debug_config is not None\n            and request_record.chat_cmpl.debug_config.ignore_eos\n        ):\n            payload[\"min_length\"] = payload[\"max_tokens\"]\n        if self.timeout is not None and \"timeout\" not in payload:\n            payload[\"timeout\"] = self.timeout\n\n        generated_text = \"\"\n        first_chunk_output_str = \"\"\n        url = self.url_stream if request_record.chat_cmpl.stream else self.url_no_stream\n        time_to_first_token_s = None\n        start_time = time.monotonic()\n\n        try:\n            async with self.client.post(url, json=payload) as response:\n                assert response.status == 200, await response.text()\n                if payload[\"stream\"]:\n                    async for chunk in response.content:\n                        chunk = chunk.strip()\n                        if not chunk or chunk == b\"\\n\":\n                            continue\n                        # Get rid of the prefix \"data:\" and suffix \"\\n\"\n                        raw_data = chunk[5:].strip()\n                        data = json.loads(raw_data)\n                        delta = data[\"text_output\"]\n                        if delta is None:\n                            continue\n\n                        if not time_to_first_token_s:\n                            time_to_first_token_s = time.monotonic() - start_time\n                            first_chunk_output_str = delta\n                        generated_text += delta\n                else:\n             
       data = await response.json()\n                    generated_text = data[\"text_output\"]\n        except Exception:  # pylint: disable=broad-except\n            error_msg = \"API endpoint errored when sending request: \" + traceback.format_exc()\n            logger.info(error_msg)\n            finish_time = time.monotonic()\n            request_record.output_str = generated_text\n            request_record.first_chunk_output_str = first_chunk_output_str\n            request_record.metrics = Metrics(\n                success=False,\n                start_time=start_time,\n                finish_time=finish_time,\n                end_to_end_latency_s=finish_time - start_time,\n                input_tokens=request_record.metrics.input_tokens,\n                time_to_first_token_s=time_to_first_token_s,\n                exec_feature=request_record.metrics.exec_feature,\n            )\n            request_record.error_msg = error_msg\n            return request_record\n\n        finish_time = time.monotonic()\n        request_record.output_str = generated_text\n        request_record.first_chunk_output_str = first_chunk_output_str\n        success = True\n        error_msg = None\n        if len(generated_text) == 0:\n            success = False\n            error_msg = \"Empty generated text.\"\n        request_record.metrics = Metrics(\n            success=success,\n            start_time=start_time,\n            finish_time=finish_time,\n            end_to_end_latency_s=finish_time - start_time,\n            input_tokens=request_record.metrics.input_tokens,\n            time_to_first_token_s=time_to_first_token_s,\n            exec_feature=request_record.metrics.exec_feature,\n        )\n        request_record.error_msg = error_msg\n        return request_record\n\n\n# Todo: APIEndPoint with AsyncOpenAI Python interface  # pylint: disable=fixme\n# class OpenAIPythonEndPoint(APIEndPoint):\n#     pass\n\nSUPPORTED_BACKENDS = [\n    \"openai\",\n    
\"openai-chat\",\n    \"mlc\",\n    \"sglang\",\n    \"tensorrt-llm\",\n    \"vllm\",\n]\n\n\ndef create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:\n    \"\"\"Create an API endpoint instance with regard to the specified endpoint kind.\"\"\"\n    if args.api_endpoint in [\"openai\", \"mlc\", \"sglang\"]:\n        return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)\n    if args.api_endpoint == \"vllm\":\n        return OpenAIEndPoint(\n            args.host,\n            args.port,\n            args.timeout,\n            include_server_metrics=False,\n            no_debug_config=True,\n        )\n    if args.api_endpoint == \"openai-chat\":\n        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)\n    if args.api_endpoint == \"tensorrt-llm\":\n        return TensorRTLLMEndPoint(args.host, args.port, args.timeout)\n    raise ValueError(f'Unrecognized endpoint \"{args.api_endpoint}\"')\n"
  },
  {
    "path": "python/mlc_llm/bench/dataset.py",
    "content": "\"\"\"MLC LLM benchmark dataset classes\"\"\"\n\nimport argparse\nimport json\nimport random\nfrom datetime import datetime\nfrom typing import Dict, List, Optional, Tuple\n\nimport numpy as np\nimport pandas as pd  # pylint: disable=import-error\nfrom datasets import load_dataset  # pylint: disable=import-error\nfrom transformers import AutoTokenizer  # pylint: disable=import-error\n\nfrom mlc_llm.bench.request_record import GroupedRequestRecord, Metrics, RequestRecord\nfrom mlc_llm.protocol.openai_api_protocol import (\n    ChatCompletionMessage,\n    ChatCompletionRequest,\n    DebugConfig,\n)\n\n\nclass Dataset:  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset base class.\"\"\"\n\n    # We set a truncation limit of 100k.\n    truncate_length = int(1e5)\n    # For some that datasets (e.g., dataset that has shared common prefix),\n    # we need fake warmup requests to avoid prefilling common prefixes to the engine.\n    require_fake_warmup: bool = False\n    # Whether the dataset contains timestamps already.\n    # If the dataset comes with timestamps, the benchmark can just replay\n    # the requests according to their timestamps.\n    timestamp_available: bool = False\n\n    def generate_request_records(\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        \"\"\"Get the raw unprocessed request records of the dataset.\"\"\"\n        raise NotImplementedError()\n\n\nclass ShareGPTDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset class for ShareGPT dataset.\"\"\"\n\n    _tokenized_dataset: List[Tuple[str, List[int], int]]\n    apply_chat_template: bool\n\n    def __init__(\n        self, dataset_path: str, tokenizer: AutoTokenizer, apply_chat_template: bool\n    ) -> None:\n        self.apply_chat_template = apply_chat_template\n        with 
open(dataset_path, encoding=\"utf-8\") as f:\n            raw_dataset = json.load(f)\n        # Filter out the conversations with less than 2 turns.\n        _dataset = [\n            (data[\"conversations\"][0][\"value\"], data[\"conversations\"][1][\"value\"])\n            for data in raw_dataset\n            if len(data[\"conversations\"]) >= 2 and data[\"conversations\"][0][\"from\"] == \"human\"\n        ]\n        # Tokenize the prompts and completions.\n        self.tokenizer = tokenizer\n        prompts = [prompt for prompt, _ in _dataset]\n        if apply_chat_template:\n            assert (\n                getattr(tokenizer, \"chat_template\", None) is not None\n            ), '\"--apply-chat-template\" is set but the tokenizer does not have chat template.'\n            prompts = [\n                tokenizer.apply_chat_template(\n                    [{\"role\": \"user\", \"content\": prompt}],\n                    add_generation_prompt=True,\n                    tokenize=False,\n                )\n                for prompt in prompts\n            ]\n\n        prompt_token_ids = list(\n            tokenizer(\n                prompts,\n                truncation=True,\n                max_length=min(tokenizer.model_max_length, self.truncate_length),\n                add_special_tokens=False,\n            ).input_ids\n        )\n        completions = [completion for _, completion in _dataset]\n        completion_token_ids = tokenizer(\n            completions,\n            truncation=True,\n            max_length=min(tokenizer.model_max_length, self.truncate_length),\n            add_special_tokens=False,\n        ).input_ids\n        self._tokenized_dataset: List[Tuple[str, List[int], int]] = []\n        for i in range(len(_dataset)):\n            if (\n                len(prompt_token_ids[i]) < 4\n                or len(completion_token_ids[i]) < 4\n                or len(prompt_token_ids[i]) + len(completion_token_ids[i])\n                >= 
min(tokenizer.model_max_length, 8192)\n            ):\n                # Filter out sequences that are too short or too long\n                continue\n            self._tokenized_dataset.append(\n                (prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))\n            )\n\n    def generate_request_records(\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        if self.apply_chat_template:\n            assert (\n                input_len is None\n            ), '\"--apply-chat-template\" is not supported when \"--input-len\" is specified.'\n\n        request_records = []\n        for prompt, input_token_ids, output_length in self._tokenized_dataset:\n            input_length = len(input_token_ids)\n            # If the request does not have enough length, discard it.\n            if input_len is not None and input_length < input_len + 4 * input_len_std:\n                continue\n\n            if input_len is not None:\n                input_length = round(\n                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])\n                )\n                input_token_ids = input_token_ids[:input_length]\n                input_truncated = True\n            else:\n                input_truncated = False\n            if output_len is not None:\n                output_length = round(\n                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])\n                )\n            elif output_length <= 1:\n                continue\n            request_records.append(\n                RequestRecord(\n                    chat_cmpl=ChatCompletionRequest(\n                        messages=[\n                            {\n                                \"role\": \"user\",\n                                \"content\": (\n                                    
self.tokenizer.decode(input_token_ids)\n                                    if input_truncated\n                                    else prompt\n                                ),\n                            }\n                        ],\n                        model=\"\",\n                        max_tokens=output_length,\n                    ),\n                    metrics=Metrics(\n                        success=False,\n                        start_time=0,\n                        finish_time=0,\n                        end_to_end_latency_s=0,\n                        input_tokens=len(input_token_ids),\n                    ),\n                )\n            )\n        return request_records\n\n\nclass LoogleDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset class for Loogle dataset.\"\"\"\n\n    # pylint: disable=line-too-long\n    task2prompt = {\n        \"shortdep_qa\": \"Please answer the question based on the long texts below. \\n{input}\\nQuestion: {Q}\\nAnswer: \",\n        \"longdep_qa\": \"Please answer the question based on the long texts below. \\n{input}\\nQuestion: {Q}\\nAnswer: \",\n        \"longdep_summarization\": \"Please generate a summary of the below paper. \\n{input}\\n Summarization: \",\n        \"shortdep_cloze\": \"Please fill in the clozes based on the given long texts below. Each of the placeholder '<mask-n>' in the question could be an entity of Person, Location or Organiocation. The same masks represent the same entity. Output a json format answer, for example: {{'<mask-0>': 'Bob', '<mask-1>': 'Gorrosion Magazine','<mask-2>': 'Bethel Horizon'}}\\n{input}\\n Question: {Q} What are the masked entities? 
\\nAnswer:\",\n    }\n    # pylint: enable=line-too-long\n    require_fake_warmup: bool = True\n\n    def __init__(self, tokenizer: AutoTokenizer, testset_name: str) -> None:\n        raw_dataset = load_dataset(\"bigainlco/LooGLE\", testset_name, split=\"test\")\n        self.tokenizer = tokenizer\n        self.dataset = []\n        self.prompt_format = self.task2prompt[testset_name]\n        prompts = []\n        generate_lens = []\n        questions = []\n        for data in raw_dataset:\n            prompt = data[\"input\"]\n            prompts.append(prompt)\n            qa_pairs = eval(data[\"qa_pairs\"])  # pylint: disable=eval-used\n            questions.append([j[\"Q\"] for j in qa_pairs])\n            generate_lens.append(\n                [len(tokenizer.encode(j[\"A\"], add_special_tokens=False)) for j in qa_pairs]\n            )\n        prompt_token_ids = tokenizer(\n            prompts,\n            truncation=True,\n            max_length=min(tokenizer.model_max_length, self.truncate_length),\n            add_special_tokens=False,\n        ).input_ids\n        for prompt, prompt_token_id, question, generate_len in zip(\n            prompts, prompt_token_ids, questions, generate_lens\n        ):\n            self.dataset.append((prompt, prompt_token_id, question, generate_len))\n\n    def generate_request_records(  # pylint: disable=too-many-locals\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        request_records = []\n        for prompt, input_token_ids, questions, generate_lens in self.dataset:\n            # Only sample a length and truncate the context when an input length is given.\n            if input_len is not None:\n                input_length = round(float(np.random.normal(loc=input_len, scale=input_len_std)))\n                if len(input_token_ids) > input_length:\n                    input_token_ids = input_token_ids[:input_length]\n                    prompt = self.tokenizer.decode(input_token_ids)\n            
grouped_request_records = []\n            for question, generate_len in zip(questions, generate_lens):\n                json_obj = {\"input\": prompt, \"Q\": question}\n                full_prompt = self.prompt_format.format(**json_obj)\n\n                output_length = (\n                    round(float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0]))\n                    if output_len is not None\n                    else generate_len\n                )\n                grouped_request_records.append(\n                    RequestRecord(\n                        chat_cmpl=ChatCompletionRequest(\n                            messages=[\n                                {\n                                    \"role\": \"user\",\n                                    \"content\": full_prompt,\n                                }\n                            ],\n                            model=\"\",\n                            max_tokens=output_length,\n                        ),\n                        metrics=Metrics(\n                            success=False,\n                            start_time=0,\n                            finish_time=0,\n                            end_to_end_latency_s=0,\n                            input_tokens=len(input_token_ids),\n                        ),\n                    )\n                )\n            request_records.append(\n                GroupedRequestRecord(\n                    # Create a dummy ChatCompletionRequest.\n                    chat_cmpl=ChatCompletionRequest(messages=[]),\n                    records=grouped_request_records,\n                )\n            )\n        return request_records\n\n\nclass LLMPerfDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset class for LLMPerf dataset.\"\"\"\n\n    def __init__(self, dataset_path: str, num_requests: int, tokenizer: AutoTokenizer) -> None:\n        self.tokenizer = tokenizer\n        self.num_requests = 
num_requests\n\n        with open(dataset_path, encoding=\"utf-8\") as f:\n            untokenized_data = f.readlines()\n        # Tokenize the prompts and completions.\n        tokenized_data = tokenizer(\n            untokenized_data,\n            truncation=True,\n            max_length=min(tokenizer.model_max_length, self.truncate_length),\n            add_special_tokens=False,\n        ).input_ids\n        tokenized_data_lengths = [len(tokens) for tokens in tokenized_data]\n        self.dataset: List[Tuple[str, List[int], int]] = list(\n            zip(untokenized_data, tokenized_data, tokenized_data_lengths)\n        )\n\n    def generate_request_records(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        input_len: Optional[int] = None,\n        output_len: Optional[int] = None,\n        input_len_std: float = 250,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        if input_len is None or input_len < 40:\n            input_len = 550\n        if output_len is None:\n            output_len = 150\n\n        request_records = []\n        for _ in range(self.num_requests):\n            input_length = round(float(np.random.normal(loc=input_len, scale=input_len_std)))\n            output_length = round(float(np.random.normal(loc=output_len, scale=output_len_std)))\n\n            prompt = (\n                \"Randomly stream lines from the following text \"\n                f\"with {output_length} output tokens. 
\"\n                \"Don't generate eos tokens:\\n\\n\"\n            )\n\n            remaining_token_length = input_length - len(\n                self.tokenizer.encode(prompt, add_special_tokens=False)\n            )\n\n            random.shuffle(self.dataset)\n\n            while remaining_token_length > 0:\n                for text, tokens, token_length in self.dataset:\n                    if remaining_token_length < token_length:\n                        prompt += self.tokenizer.decode(tokens[:remaining_token_length])\n                    else:\n                        prompt += text\n\n                    remaining_token_length -= token_length\n                    if remaining_token_length < 0:\n                        break\n\n            request_records.append(\n                RequestRecord(\n                    chat_cmpl=ChatCompletionRequest(\n                        messages=[{\"role\": \"user\", \"content\": prompt}],\n                        model=\"\",\n                        max_tokens=output_length,\n                        debug_config=DebugConfig(ignore_eos=True),\n                    ),\n                    metrics=Metrics(\n                        success=False,\n                        start_time=0,\n                        finish_time=0,\n                        end_to_end_latency_s=0,\n                        input_tokens=input_length,\n                    ),\n                )\n            )\n        return request_records\n\n\nclass JSONModeEvalDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset class for JSON dataset.\"\"\"\n\n    def __init__(self, tokenizer: AutoTokenizer) -> None:\n        raw_dataset = load_dataset(\"NousResearch/json-mode-eval\")\n        self.tokenizer = tokenizer\n        self.dataset = []\n        for data in raw_dataset[\"train\"]:\n            messages = data[\"prompt\"]\n            schema = {\n                \"type\": \"json_object\",\n                \"schema\": 
data[\"schema\"],\n            }\n            num_tokens = 0\n            for message in messages:\n                num_tokens += len(\n                    self.tokenizer.encode(message[\"content\"], add_special_tokens=False)\n                )\n            self.dataset.append((messages, schema, num_tokens))\n\n    def generate_request_records(\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        request_records = []\n        for messages, schema, num_tokens in self.dataset:\n            # If the request does not have enough length, discard it.\n            if input_len is not None and num_tokens < input_len + 4 * input_len_std:\n                continue\n\n            if output_len is not None:\n                output_length = max(\n                    round(np.random.normal(loc=output_len, scale=output_len_std)), 1\n                )\n            else:\n                output_length = None\n            request_records.append(\n                RequestRecord(\n                    chat_cmpl=ChatCompletionRequest(\n                        messages=[\n                            ChatCompletionMessage(content=message[\"content\"], role=message[\"role\"])\n                            for message in messages\n                        ],\n                        model=\"\",\n                        max_tokens=output_length,\n                        response_format=schema,\n                    ),\n                    metrics=Metrics(\n                        success=False,\n                        start_time=0,\n                        finish_time=0,\n                        end_to_end_latency_s=0,\n                        input_tokens=num_tokens,\n                    ),\n                )\n            )\n        return request_records\n\n\nclass ReActDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The 
dataset class for replaying a given ReAct trace for benchmark purpose.\n    It is not an actual ReAct agent implementation.\n    \"\"\"\n\n    _dataset: List[List[Tuple[str, int, int]]]\n    require_fake_warmup: bool = True\n    # pylint: disable=line-too-long\n    prefix: str = \"\"\"Solve a question answering task with interleaving Thought, Action, Observation steps. Thought can reason about the current situation, and Action can be three types:\n(1) Search[entity], which searches the exact entity on Wikipedia and returns the first paragraph if it exists. If not, it will return some similar entities to search.\n(2) Lookup[keyword], which returns the next sentence containing keyword in the current passage.\n(3) Finish[answer], which returns the answer and finishes the task.\nHere are some examples.\nQuestion: What is the elevation range for the area that the eastern sector of the Colorado orogeny extends into?\nThought 1: I need to search Colorado orogeny, find the area that the eastern sector of the Colorado orogeny extends into, then find the elevation range of the area.\nAction 1: Search[Colorado orogeny]\nObservation 1: The Colorado orogeny was an episode of mountain building (an orogeny) in Colorado and surrounding areas.\nThought 2: It does not mention the eastern sector. So I need to look up eastern sector.\nAction 2: Lookup[eastern sector]\nObservation 2: (Result 1 / 1) The eastern sector extends into the High Plains and is called the Central Plains orogeny.\nThought 3: The eastern sector of Colorado orogeny extends into the High Plains. So I need to search High Plains and find its elevation range.\nAction 3: Search[High Plains]\nObservation 3: High Plains refers to one of two distinct land regions:\nThought 4: I need to instead search High Plains (United States).\nAction 4: Search[High Plains (United States)]\nObservation 4: The High Plains are a subregion of the Great Plains. 
From east to west, the High Plains rise in elevation from around 1,800 to 7,000 ft (550 to 2,130 m).[3]\nThought 5: High Plains rise in elevation from around 1,800 to 7,000 ft, so the answer is 1,800 to 7,000 ft.\nAction 5: Finish[1,800 to 7,000 ft]\nQuestion: Musician and satirist Allie Goertz wrote a song about the \"The Simpsons\" character Milhouse, who Matt Groening named after who?\nThought 1: The question simplifies to \"The Simpsons\" character Milhouse is named after who. I only need to search Milhouse and find who it is named after.\nAction 1: Search[Milhouse]\nObservation 1: Milhouse Mussolini Van Houten is a recurring character in the Fox animated television series The Simpsons voiced by Pamela Hayden and created by Matt Groening.\nThought 2: The paragraph does not tell who Milhouse is named after, maybe I can look up \"named after\".\nAction 2: Lookup[named after]\nObservation 2: (Result 1 / 1) Milhouse was named after U.S. president Richard Nixon, whose middle name was Milhous.\nThought 3: Milhouse was named after U.S. president Richard Nixon, so the answer is Richard Nixon.\nAction 3: Finish[Richard Nixon]\nQuestion: Which documentary is about Finnish rock groups, Adam Clayton Powell or The Saimaa Gesture?\nThought 1: I need to search Adam Clayton Powell and The Saimaa Gesture, and find which documentary is about Finnish rock groups.\nAction 1: Search[Adam Clayton Powell]\nObservation 1: Could not find [Adam Clayton Powell]. Similar: ['Adam Clayton Powell III', 'Seventh Avenue (Manhattan)', 'Adam Clayton Powell Jr. 
State Office Building', 'Isabel Washington Powell', 'Adam Powell', 'Adam Clayton Powell (film)', 'Giancarlo Esposito'].\nThought 2: To find the documentary, I can search Adam Clayton Powell (film).\nAction 2: Search[Adam Clayton Powell (film)]\nObservation 2: Adam Clayton Powell is a 1989 American documentary film directed by Richard Kilberg.\nThe film is about the rise and fall of influential African-American politician Adam Clayton Powell Jr.[3][4] It was later aired as part of the PBS series The American Experience.\nThought 3: Adam Clayton Powell (film) is a documentary about an African-American politician, not Finnish rock groups. So the documentary about Finnish rock groups must instead be The Saimaa Gesture.\nAction 3: Finish[The Saimaa Gesture]\nQuestion: What profession does Nicholas Ray and Elia Kazan have in common?\nThought 1: I need to search Nicholas Ray and Elia Kazan, find their professions, then find the profession they have in common.\nAction 1: Search[Nicholas Ray]\nObservation 1: Nicholas Ray (born Raymond Nicholas Kienzle Jr., August 7, 1911 – June 16, 1979) was an American film director, screenwriter, and actor best known for the 1955 film Rebel Without a Cause.\nThought 2: Professions of Nicholas Ray are director, screenwriter, and actor. I need to search Elia Kazan next and find his professions.\nAction 2: Search[Elia Kazan]\nObservation 2: Elia Kazan was an American film and theatre director, producer, screenwriter and actor.\nThought 3: Professions of Elia Kazan are director, producer, screenwriter, and actor. 
So profession Nicholas Ray and Elia Kazan have in common is director, screenwriter, and actor.\nAction 3: Finish[director, screenwriter, actor]\nQuestion: Which magazine was started first Arthur's Magazine or First for Women?\nThought 1: I need to search Arthur's Magazine and First for Women, and find which was started first.\nAction 1: Search[Arthur's Magazine]\nObservation 1: Arthur's Magazine (1844-1846) was an American literary periodical published in Philadelphia in the 19th century.\nThought 2: Arthur's Magazine was started in 1844. I need to search First for Women next.\nAction 2: Search[First for Women]\nObservation 2: First for Women is a woman's magazine published by Bauer Media Group in the USA.[1] The magazine was started in 1989.\nThought 3: First for Women was started in 1989. 1844 (Arthur's Magazine) < 1989 (First for Women), so Arthur's Magazine was started first.\nAction 3: Finish[Arthur's Magazine]\nQuestion: Were Pavel Urysohn and Leonid Levin known for the same type of work?\nThought 1: I need to search Pavel Urysohn and Leonid Levin, find their types of work, then find if they are the same.\nAction 1: Search[Pavel Urysohn]\nObservation 1: Pavel Samuilovich Urysohn (February 3, 1898 â August 17, 1924) was a Soviet mathematician who is best known for his contributions in dimension theory.\nThought 2: Pavel Urysohn is a mathematician. I need to search Leonid Levin next and find its type of work.\nAction 2: Search[Leonid Levin]\nObservation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.\nThought 3: Leonid Levin is a mathematician and computer scientist. 
So Pavel Urysohn and Leonid Levin have the same type of work.\nAction 3: Finish[yes]\n\"\"\"\n\n    # pylint: enable=line-too-long\n    def __init__(  # pylint: disable=too-many-locals\n        self, dataset_path: str, tokenizer: AutoTokenizer\n    ) -> None:\n        raw_entries: List[Dict] = []\n        with open(dataset_path) as fin:  # pylint: disable=unspecified-encoding\n            for line in fin:\n                line_content = json.loads(line)\n                raw_entries += list({\"question\": k, \"triplets\": v} for k, v in line_content.items())\n\n        self._dataset = []\n        max_rounds = 0\n        for raw_entry in raw_entries:\n            processed_entry = []\n            question = raw_entry[\"question\"]\n            triplets = raw_entry[\"triplets\"]\n            seq = self.prefix + question\n            max_rounds = max(max_rounds, len(triplets) + 1)\n            output_lengths: List[int] = []\n            for i, triplet in enumerate(triplets):\n                output_lengths.append(\n                    len(\n                        tokenizer(\n                            triplet[\"thought\"]\n                            + \"\\nAction \"\n                            + str(i + 1)\n                            + \": \"\n                            + triplet[\"action\"]\n                            + \"\\n\",\n                            truncation=True,\n                            max_length=min(tokenizer.model_max_length, self.truncate_length),\n                            add_special_tokens=False,\n                        ).input_ids\n                    )\n                )\n\n            for i in range(1, len(triplets) + 2):\n                seq += \"Thought \" + str(i) + \":\"\n                input_len = len(\n                    tokenizer(\n                        seq,\n                        truncation=True,\n                        max_length=min(tokenizer.model_max_length, self.truncate_length),\n                        
add_special_tokens=False,\n                    ).input_ids\n                )\n                output_length = (\n                    output_lengths[i - 1]\n                    if i <= len(triplets)\n                    else int(sum(output_lengths) / len(triplets))\n                )\n                processed_entry.append((seq, input_len, output_length))\n                if i != len(triplets) + 1:\n                    seq += (\n                        triplets[i - 1][\"thought\"]\n                        + \"\\nAction \"\n                        + str(i)\n                        + \": \"\n                        + triplets[i - 1][\"action\"]\n                        + \"\\nObservation \"\n                        + str(i)\n                        + \": \"\n                        + triplets[i - 1][\"observation\"]\n                        + \"\\n\"\n                    )\n            self._dataset.append(processed_entry)\n\n    def generate_request_records(\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        if input_len is not None or output_len is not None:\n            raise ValueError(\"ReAct dataset does not support specifying input/output length.\")\n\n        request_records = []\n        for processed_entries in self._dataset:\n            grouped_request_records = []\n            for prompt, input_length, output_length in processed_entries:\n                grouped_request_records.append(\n                    RequestRecord(\n                        chat_cmpl=ChatCompletionRequest(\n                            messages=[{\"role\": \"user\", \"content\": prompt}],\n                            model=\"\",\n                            max_tokens=output_length,\n                        ),\n                        metrics=Metrics(\n                            success=False,\n                            
start_time=0,\n                            finish_time=0,\n                            end_to_end_latency_s=0,\n                            input_tokens=input_length,\n                        ),\n                    )\n                )\n            request_records.append(\n                GroupedRequestRecord(\n                    # Create a dummy ChatCompletionRequest.\n                    chat_cmpl=ChatCompletionRequest(messages=[]),\n                    records=grouped_request_records,\n                )\n            )\n        return request_records\n\n\nclass WildChatDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset class for WildChat dataset.\"\"\"\n\n    apply_chat_template: bool\n\n    def __init__(self, tokenizer: AutoTokenizer, apply_chat_template: bool) -> None:\n        raw_dataset = load_dataset(\"allenai/WildChat\", split=\"train\")\n        self.tokenizer = tokenizer\n        self.apply_chat_template = apply_chat_template\n\n        # Filter out the conversations with less than 2 turns.\n        _dataset = [\n            (entry[\"conversation\"][0][\"content\"], entry[\"conversation\"][1][\"content\"])\n            for entry in raw_dataset\n            if len(entry[\"conversation\"]) >= 2\n            and entry[\"conversation\"][0][\"role\"] == \"user\"\n            and entry[\"conversation\"][1][\"role\"] == \"assistant\"\n        ]\n\n        prompts = []\n        completions = []\n        for prompt, completion in _dataset:\n            prompts.append(prompt)\n            completions.append(completion)\n        if apply_chat_template:\n            assert (\n                getattr(tokenizer, \"chat_template\", None) is not None\n            ), '\"--apply-chat-template\" is set but the tokenizer does not have chat template.'\n            prompts = [\n                tokenizer.apply_chat_template(\n                    [{\"role\": \"user\", \"content\": prompt}],\n                    add_generation_prompt=True,\n     
               tokenize=False,\n                )\n                for prompt in prompts\n            ]\n\n        prompt_token_ids = list(\n            tokenizer(\n                prompts,\n                truncation=True,\n                max_length=min(tokenizer.model_max_length, self.truncate_length),\n                add_special_tokens=False,\n            ).input_ids\n        )\n        completion_token_ids = tokenizer(\n            completions,\n            truncation=True,\n            max_length=min(tokenizer.model_max_length, self.truncate_length),\n            add_special_tokens=False,\n        ).input_ids\n        self._tokenized_dataset: List[Tuple[str, List[int], int]] = []\n        for i in range(len(_dataset)):\n            if len(prompt_token_ids[i]) < 4 or len(completion_token_ids[i]) < 4:\n                # Filter out sequences that are too short\n                continue\n            self._tokenized_dataset.append(\n                (prompts[i], prompt_token_ids[i], len(completion_token_ids[i]))\n            )\n\n    def generate_request_records(  # pylint: disable=too-many-locals\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        if self.apply_chat_template:\n            assert (\n                input_len is None\n            ), '\"--apply-chat-template\" is not supported when \"--input-len\" is specified.'\n\n        request_records = []\n        for prompt, input_token_ids, output_length in self._tokenized_dataset:\n            input_length = len(input_token_ids)\n            # If the request does not have enough length, discard it.\n            if input_len is not None and input_length < input_len + 4 * input_len_std:\n                continue\n\n            if input_len is not None:\n                input_length = round(\n                    float(np.random.normal(loc=input_len, 
scale=input_len_std, size=1)[0])\n                )\n                input_token_ids = input_token_ids[:input_length]\n                input_truncated = True\n            else:\n                input_truncated = False\n            if output_len is not None:\n                output_length = round(\n                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])\n                )\n            elif output_length <= 1:\n                continue\n            request_records.append(\n                RequestRecord(\n                    chat_cmpl=ChatCompletionRequest(\n                        messages=[\n                            {\n                                \"role\": \"user\",\n                                \"content\": (\n                                    self.tokenizer.decode(input_token_ids)\n                                    if input_truncated\n                                    else prompt\n                                ),\n                            }\n                        ],\n                        model=\"\",\n                        max_tokens=output_length,\n                    ),\n                    metrics=Metrics(\n                        success=False,\n                        start_time=0,\n                        finish_time=0,\n                        end_to_end_latency_s=0,\n                        input_tokens=len(input_token_ids),\n                    ),\n                )\n            )\n        return request_records\n\n\nclass AzureLLMInferenceDataset(Dataset):  # pylint: disable=too-few-public-methods\n    \"\"\"The dataset class for AzureLLMInference dataset.\n    Reference: https://github.com/Azure/AzurePublicDataset\n    \"\"\"\n\n    timestamp_available: bool = True\n\n    def __init__(self, dataset_path: str, tokenizer: AutoTokenizer) -> None:\n        df = pd.read_csv(dataset_path)\n        self.tokenizer = tokenizer\n\n        # Filter out the entries with fewer than 4 context tokens or 4 generated tokens.\n       
 self.dataset = [\n            (\n                entry[\"TIMESTAMP\"],\n                min(\n                    entry[\"ContextTokens\"],\n                    tokenizer.model_max_length,\n                    self.truncate_length,\n                ),\n                min(\n                    entry[\"GeneratedTokens\"],\n                    tokenizer.model_max_length,\n                    self.truncate_length,\n                ),\n            )\n            for _, entry in df.iterrows()\n            if entry[\"ContextTokens\"] >= 4 and entry[\"GeneratedTokens\"] >= 4\n        ]\n\n    def generate_request_records(  # pylint: disable=too-many-locals\n        self,\n        input_len: Optional[int],\n        output_len: Optional[int],\n        input_len_std: float = 0.0,\n        output_len_std: float = 0.0,\n    ) -> List[RequestRecord]:\n        time_fmt = \"%Y-%m-%d %H:%M:%S.%f\"\n        start_time = datetime.strptime(self.dataset[0][0][:-1], time_fmt)\n        request_records = []\n        for timestamp, input_length, output_length in self.dataset:\n            # If the request does not have enough length, discard it.\n            if input_len is not None and input_length < input_len + 4 * input_len_std:\n                continue\n\n            if input_len is not None:\n                input_length = round(\n                    float(np.random.normal(loc=input_len, scale=input_len_std, size=1)[0])\n                )\n            if output_len is not None:\n                output_length = round(\n                    float(np.random.normal(loc=output_len, scale=output_len_std, size=1)[0])\n                )\n            elif output_length <= 1:\n                continue\n\n            prompt_token_ids = [\n                random.randint(0, self.tokenizer.vocab_size - 1) for _ in range(input_length)\n            ]\n            while True:\n                # Adjust the token ids until the retokenization on the decoded string\n                # matches the 
required input length.\n                prompt = self.tokenizer.decode(prompt_token_ids)\n                retokenized_token_ids = self.tokenizer.encode(prompt, add_special_tokens=False)\n                if len(retokenized_token_ids) < input_length:\n                    prompt_token_ids = retokenized_token_ids + [\n                        random.randint(0, self.tokenizer.vocab_size - 1)\n                        for _ in range(input_length - len(retokenized_token_ids))\n                    ]\n                elif len(retokenized_token_ids) > input_length:\n                    prompt_token_ids = retokenized_token_ids[:input_length]\n                else:\n                    break\n\n            time_diff = (datetime.strptime(timestamp[:-1], time_fmt) - start_time).total_seconds()\n            request_records.append(\n                RequestRecord(\n                    chat_cmpl=ChatCompletionRequest(\n                        messages=[{\"role\": \"user\", \"content\": prompt}],\n                        model=\"\",\n                        max_tokens=output_length,\n                    ),\n                    timestamp=time_diff,\n                    metrics=Metrics(\n                        success=False,\n                        start_time=0,\n                        finish_time=0,\n                        end_to_end_latency_s=0,\n                        input_tokens=input_length,\n                    ),\n                )\n            )\n        return request_records\n\n\nSUPPORTED_DATASET = [\n    \"sharegpt\",\n    \"llmperf\",\n    \"json-mode-eval\",\n    \"loogle\",\n    \"react\",\n    \"wildchat\",\n    \"azure-llm-inference\",\n]\n\n\ndef create_dataset(  # pylint: disable=too-many-return-statements,too-many-branches\n    args: argparse.Namespace, tokenizer: AutoTokenizer\n) -> Dataset:\n    \"\"\"Create a dataset instance with regard to the specified dataset kind and file path.\"\"\"\n    if args.dataset_path is not None and not 
isinstance(args.dataset_path, str):\n        raise TypeError(f\"Invalid dataset path {args.dataset_path}. Please use a string.\")\n    if args.dataset is None and args.dataset_path is not None:\n        # Auto-detect the dataset kind by looking into the dataset path.\n        if \"sharegpt\" in args.dataset_path.lower():\n            args.dataset = \"sharegpt\"\n        else:\n            raise ValueError(\n                f\"Unable to detect the dataset kind from dataset path {args.dataset_path}. \"\n                'Please specify the dataset kind via \"--dataset\".'\n            )\n    if args.dataset == \"sharegpt\":\n        if args.dataset_path is None:\n            raise ValueError(\n                'ShareGPT dataset requires dataset path. Please specify it with \"--dataset-path\".'\n            )\n        return ShareGPTDataset(args.dataset_path, tokenizer, args.apply_chat_template)\n    if args.dataset == \"llmperf\":\n        if args.dataset_path is None:\n            raise ValueError(\n                'LLMPerf dataset requires dataset path. Please specify it with \"--dataset-path\".'\n            )\n        assert (\n            args.apply_chat_template is False\n        ), \"LLMPerf dataset does not support applying chat template\"\n        return LLMPerfDataset(\n            args.dataset_path,\n            (args.num_requests + args.num_warmup_requests) * 4,\n            tokenizer,\n        )\n    if args.dataset == \"json-mode-eval\":\n        assert (\n            args.apply_chat_template is False\n        ), \"JSON mode evaluation does not support applying chat template\"\n        return JSONModeEvalDataset(tokenizer)\n    if args.dataset == \"loogle\":\n        if args.dataset_path is None:\n            raise ValueError(\n                'Loogle dataset requires a testset name. 
Please specify it with \"--dataset-path\".'\n            )\n        assert (\n            args.apply_chat_template is False\n        ), \"Loogle dataset does not support applying chat template\"\n        return LoogleDataset(tokenizer, testset_name=args.dataset_path)\n    if args.dataset == \"react\":\n        if args.dataset_path is None:\n            raise ValueError(\n                'ReAct dataset requires dataset path. Please specify it with \"--dataset-path\".'\n            )\n        assert (\n            args.apply_chat_template is False\n        ), \"ReAct dataset does not support applying chat template\"\n        return ReActDataset(args.dataset_path, tokenizer)\n    if args.dataset == \"wildchat\":\n        return WildChatDataset(tokenizer, args.apply_chat_template)\n    if args.dataset == \"azure-llm-inference\":\n        if args.dataset_path is None:\n            raise ValueError(\n                \"AzureLLMInference dataset requires dataset path. \"\n                'Please specify it with \"--dataset-path\".'\n            )\n        assert (\n            args.apply_chat_template is False\n        ), \"AzureLLMInference dataset does not support applying chat template\"\n        return AzureLLMInferenceDataset(args.dataset_path, tokenizer)\n    raise ValueError(f\"Unrecognized dataset {args.dataset}\")\n"
  },
  {
    "path": "python/mlc_llm/bench/evaluation/gsm8k.py",
    "content": "\"\"\"Eval GSM8K with MLCEngine.\"\"\"\n\nimport argparse\nimport asyncio\nimport json\nimport random\nimport re\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import List, Literal, Optional\n\nimport tqdm\n\nfrom mlc_llm import AsyncMLCEngine\n\nDEVICES = [\"cuda\", \"rocm\", \"metal\", \"vulkan\"]\nANSWER_TRIGGER = \"The answer is\"\nINVALID_ANS = \"[invalid]\"\n\n\ndef extract_answer(text: str, regex: re.Pattern, select_index: int) -> str:\n    \"\"\"Extract the answer from the text.\"\"\"\n    match_all = regex.findall(text)\n    if len(match_all) == 0:\n        return INVALID_ANS\n    match = match_all[select_index]\n    if isinstance(match, tuple):\n        match = [m for m in match if m][0]\n    match_str: str = match.strip()\n    match_str = match_str.lstrip(\"$\").rstrip(\".\").replace(\",\", \"\")\n    return match_str\n\n\ndef extract_ground_truth(text: str) -> str:\n    \"\"\"Extract the ground truth from the text.\"\"\"\n    return extract_answer(text, re.compile(r\"#### (\\-?[0-9\\.\\,]+)\"), 0)\n\n\ndef strict_extract_answer(text: str) -> str:\n    \"\"\"Strictly extract the answer from the text.\"\"\"\n    return extract_answer(text, re.compile(r\"The answer is \\$?(\\-?[0-9\\.\\,]+).\"), 0)\n\n\ndef flexible_extract_answer(text: str) -> str:\n    \"\"\"Extract the last number from the text.\"\"\"\n    return extract_answer(text, re.compile(r\"(-?[$0-9.,]{2,})|(-?[0-9]+)\"), -1)\n\n\ndef create_few_shot_prompt(n_shot: int, use_cot: bool, random_order=False) -> str:\n    \"\"\"\n    Create a prompt for the few-shot learning task.\n\n    Note\n    ----\n    The examples are taken from the paper https://arxiv.org/pdf/2201.11903.pdf page 35.\n    \"\"\"\n    question, chain, answer = [], [], []\n\n    question.append(\n        \"There are 15 trees in the grove. \"\n        \"Grove workers will plant trees in the grove today. \"\n        \"After they are done, there will be 21 trees. 
\"\n        \"How many trees did the grove workers plant today?\"\n    )\n    chain.append(\n        \"There are 15 trees originally. \"\n        \"Then there were 21 trees after some more were planted. \"\n        \"So there must have been 21 - 15 = 6.\"\n    )\n    answer.append(\"6\")\n\n    question.append(\n        \"If there are 3 cars in the parking lot and 2 more cars arrive, \"\n        \"how many cars are in the parking lot?\"\n    )\n    chain.append(\"There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.\")\n    answer.append(\"5\")\n\n    question.append(\n        \"Leah had 32 chocolates and her sister had 42. If they ate 35, \"\n        \"how many pieces do they have left in total?\"\n    )\n    chain.append(\n        \"Originally, Leah had 32 chocolates. \"\n        \"Her sister had 42. So in total they had 32 + 42 = 74. \"\n        \"After eating 35, they had 74 - 35 = 39.\"\n    )\n    answer.append(\"39\")\n\n    question.append(\n        \"Jason had 20 lollipops. He gave Denny some lollipops. Now Jason \"\n        \"has 12 lollipops. How many lollipops did Jason give to Denny?\"\n    )\n    chain.append(\n        \"Jason started with 20 lollipops. Then he had 12 after giving some \"\n        \"to Denny. So he gave Denny 20 - 12 = 8.\"\n    )\n    answer.append(\"8\")\n\n    question.append(\n        \"Shawn has five toys. For Christmas, he got two toys each from his \"\n        \"mom and dad. How many toys does he have now?\"\n    )\n    chain.append(\n        \"Shawn started with 5 toys. If he got 2 toys each from his mom and \"\n        \"dad, then that is 4 more toys. 5 + 4 = 9.\"\n    )\n    answer.append(\"9\")\n\n    question.append(\n        \"There were nine computers in the server room. Five more computers \"\n        \"were installed each day, from monday to thursday. \"\n        \"How many computers are now in the server room?\"\n    )\n    chain.append(\n        \"There were originally 9 computers. 
For each of 4 days, 5 more \"\n        \"computers were added. So 5 * 4 = 20 computers were added. \"\n        \"9 + 20 is 29.\"\n    )\n    answer.append(\"29\")\n\n    question.append(\n        \"Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On \"\n        \"wednesday, he lost 2 more. \"\n        \"How many golf balls did he have at the end of wednesday?\"\n    )\n    chain.append(\n        \"Michael started with 58 golf balls. After losing 23 on tuesday, \"\n        \"he had 58 - 23 = 35. After losing 2 more, \"\n        \"he had 35 - 2 = 33 golf balls.\"\n    )\n    answer.append(\"33\")\n\n    question.append(\n        \"Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\"\n    )\n    chain.append(\n        \"Olivia had 23 dollars. \"\n        \"5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. \"\n        \"So she has 23 - 15 dollars left. 23 - 15 is 8.\"\n    )\n    answer.append(\"8\")\n\n    index_list = list(range(len(question)))\n    if random_order:\n        random.shuffle(index_list)\n\n    prompt = \"\"\n    for i in index_list[:n_shot]:\n        if use_cot:\n            prompt += f\"Q: {question[i]}\\nA: {chain[i]} {ANSWER_TRIGGER} {answer[i]}.\\n\\n\"\n        else:\n            prompt += f\"Question: {question[i]}\\nAnswer: {ANSWER_TRIGGER} {answer[i]}.\\n\\n\"\n    return prompt\n\n\ndef create_prompt(question: str, n_shot: int, use_cot: bool, random_order: bool = False) -> str:\n    \"\"\"Create a prompt for the few-shot learning task.\"\"\"\n    prompt = create_few_shot_prompt(n_shot, use_cot, random_order)\n    if use_cot:\n        prompt += f\"Q: {question}\\nA:\"\n    else:\n        prompt += f\"Question: {question}\\nAnswer:\"\n    return prompt\n\n\ndef parse_args():\n    \"\"\"Parse command line arguments.\"\"\"\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\n        \"--dataset\", type=Path, 
required=True, help=\"Path to GSM8K test dataset home.\"\n    )\n    parser.add_argument(\"--device\", type=str, choices=[\"auto\"] + DEVICES, default=\"auto\")\n    parser.add_argument(\"--model-lib\", type=str, default=None)\n    parser.add_argument(\"--n-shot\", type=int, default=8)\n    parser.add_argument(\"--disable_cot\", action=\"store_true\", default=False)\n    parser.add_argument(\"-bs\", \"--batch-size\", type=int, default=16)\n    parser.add_argument(\"--log-dir\", type=Path, default=None)\n    return parser.parse_args()\n\n\nasync def send_request(\n    async_engine: AsyncMLCEngine,\n    prompts: List[str],\n    semaphore: asyncio.Semaphore,\n):\n    \"\"\"Send the GSM8K evaluation prompts to the engine, bounding concurrency with the semaphore.\"\"\"\n    tasks = []\n\n    async def generate_task(prompt):\n        async with semaphore:\n            return await async_engine.completions.create(\n                prompt=prompt,\n                stream=False,\n                max_tokens=512,\n                stop=[\"Q:\", \"Question:\"],\n                temperature=0.0,\n            )\n\n    for prompt in prompts:\n        task = asyncio.create_task(generate_task(prompt))\n        tasks.append(task)\n\n    return await tqdm.asyncio.tqdm.gather(*tasks)\n\n\nasync def evaluate(  # pylint: disable=too-many-arguments, too-many-locals\n    model: str,\n    device: str,\n    dataset: Path,\n    model_lib: Optional[str],\n    n_shot: int,\n    use_cot: bool,\n    batch_size: int,\n    log_dir: Optional[Path],  # pylint: disable=redefined-outer-name\n):\n    \"\"\"Evaluate GSM8K for the model.\"\"\"\n    mode: Literal[\"local\", \"interactive\", \"server\"] = (\n        \"server\" if batch_size > 4 else \"interactive\" if batch_size == 1 else \"local\"\n    )\n    async_engine = AsyncMLCEngine(model, device=device, model_lib=model_lib, mode=mode)\n\n    with open(dataset / \"test.jsonl\", \"r\", encoding=\"utf-8\") as file:\n        tests = [json.loads(line) 
for line in file]\n\n    prompts = 
[create_prompt(test[\"question\"], n_shot, use_cot) for test in tests]\n    responses = await send_request(async_engine, prompts, asyncio.Semaphore(batch_size))\n    assert len(responses) == len(tests)\n\n    num_strict_correct, num_flexible_correct = 0, 0\n    num_tests = len(tests)\n    logs = []\n\n    for response, test in zip(responses, tests):\n        response_text = response.choices[0].text.strip()\n        gt_answer = extract_ground_truth(test[\"answer\"])\n        assert gt_answer != INVALID_ANS\n        strict_answer = strict_extract_answer(response_text)\n        flexible_answer = flexible_extract_answer(response_text)\n\n        if gt_answer == strict_answer:\n            # A strict match also counts toward the flexible accuracy.\n            num_strict_correct += 1\n            num_flexible_correct += 1\n\n        elif gt_answer == flexible_answer:\n            # Fall back to the last number in the response when the strict match fails.\n            num_flexible_correct += 1\n\n        logs.append(\n            {\n                \"question\": test[\"question\"],\n                \"response\": response_text,\n                \"ground_truth\": gt_answer,\n                \"strict_answer\": strict_answer,\n                \"flexible_answer\": flexible_answer,\n                \"strict_match\": gt_answer == strict_answer,\n                \"flexible_match\": gt_answer == flexible_answer,\n            }\n        )\n\n    results = {\n        \"config\": {\n            \"model\": model,\n            \"device\": device,\n            \"model_lib\": model_lib,\n            \"n_shot\": n_shot,\n            \"use_cot\": use_cot,\n        },\n        \"results\": {\n            \"strict_match\": num_strict_correct,\n            \"flexible_match\": num_flexible_correct,\n            \"total\": num_tests,\n        },\n    }\n    print(\n        f\"Strict Matching Accuracy: {num_strict_correct} / {num_tests} = \"\n        
f\"{num_strict_correct / num_tests * 100:.2f}%\"\n    )\n    print(\n        f\"Flexible Matching Accuracy: {num_flexible_correct} / {num_tests} = \"\n        f\"{num_flexible_correct / num_tests * 100:.2f}%\"\n    )\n\n    if log_dir:\n        with open(log_dir / \"summary.json\", \"w\", encoding=\"utf-8\") as f:\n            json.dump(results, f, indent=2)\n        with open(log_dir / \"logs.json\", \"w\", encoding=\"utf-8\") as f:\n            json.dump(logs, f, indent=2)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    start_time = datetime.now()\n    log_dir: Optional[Path] = None\n    if args.log_dir is not None:\n        time_dir = start_time.strftime(\"%Y-%m-%d_%H-%M-%S\")\n        log_dir = args.log_dir / time_dir\n        log_dir.mkdir(parents=True, exist_ok=True)\n    asyncio.run(\n        evaluate(\n            model=args.model,\n            device=args.device,\n            dataset=args.dataset,\n            model_lib=args.model_lib,\n            n_shot=args.n_shot,\n            use_cot=not args.disable_cot,\n            batch_size=args.batch_size,\n            log_dir=log_dir,\n        )\n    )\n    end_time = datetime.now()\n    print(f\"Time used: {end_time - start_time}\")\n"
  },
  {
    "path": "python/mlc_llm/bench/evaluation/mmlu.py",
    "content": "\"\"\"Eval MMLU with MLCEngine.\"\"\"\n\nimport argparse\nimport asyncio\nimport csv\nimport json\nimport string\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional\n\nimport numpy as np\nimport tqdm\n\nfrom mlc_llm import AsyncMLCEngine\n\nSUBJECTS = [\n    \"abstract_algebra\",\n    \"anatomy\",\n    \"astronomy\",\n    \"business_ethics\",\n    \"clinical_knowledge\",\n    \"college_biology\",\n    \"college_chemistry\",\n    \"college_computer_science\",\n    \"college_mathematics\",\n    \"college_medicine\",\n    \"college_physics\",\n    \"computer_security\",\n    \"conceptual_physics\",\n    \"econometrics\",\n    \"electrical_engineering\",\n    \"elementary_mathematics\",\n    \"formal_logic\",\n    \"global_facts\",\n    \"high_school_biology\",\n    \"high_school_chemistry\",\n    \"high_school_computer_science\",\n    \"high_school_european_history\",\n    \"high_school_geography\",\n    \"high_school_government_and_politics\",\n    \"high_school_macroeconomics\",\n    \"high_school_mathematics\",\n    \"high_school_microeconomics\",\n    \"high_school_physics\",\n    \"high_school_psychology\",\n    \"high_school_statistics\",\n    \"high_school_us_history\",\n    \"high_school_world_history\",\n    \"human_aging\",\n    \"human_sexuality\",\n    \"international_law\",\n    \"jurisprudence\",\n    \"logical_fallacies\",\n    \"machine_learning\",\n    \"management\",\n    \"marketing\",\n    \"medical_genetics\",\n    \"miscellaneous\",\n    \"moral_disputes\",\n    \"moral_scenarios\",\n    \"nutrition\",\n    \"philosophy\",\n    \"prehistory\",\n    \"professional_accounting\",\n    \"professional_law\",\n    \"professional_medicine\",\n    \"professional_psychology\",\n    \"public_relations\",\n    \"security_studies\",\n    \"sociology\",\n    \"us_foreign_policy\",\n    \"virology\",\n    \"world_religions\",\n]\nPADDING_LEN = max(len(subject) for subject in SUBJECTS)\nDEVICES 
= [\"cuda\", \"rocm\", \"metal\", \"vulkan\"]\nPROMPT_TEMPLATE = string.Template(\"$Q\\nA. $A\\nB. $B\\nC. $C\\nD. $D\\nAnswer:\")\n\n\ndef parse_args():\n    \"\"\"Parse command line arguments.\"\"\"\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, required=True)\n    parser.add_argument(\n        \"--dataset\", type=Path, required=True, help=\"Path to MMLU test dataset home.\"\n    )\n    parser.add_argument(\"--device\", type=str, choices=[\"auto\"] + DEVICES, default=\"auto\")\n    parser.add_argument(\"--model-lib\", type=str, default=None)\n    parser.add_argument(\"-s\", \"--subject\", nargs=\"+\", type=str, choices=SUBJECTS, default=SUBJECTS)\n    parser.add_argument(\"-bs\", \"--batch-size\", type=int, default=16)\n    parser.add_argument(\"--log-dir\", type=Path, default=None)\n    return parser.parse_args()\n\n\nasync def send_request(\n    async_engine: AsyncMLCEngine,\n    prompts: List[str],\n    semaphore: asyncio.Semaphore,\n    subject: str,\n):\n    \"\"\"Send the MMLU evaluation prompts for one subject to the engine.\"\"\"\n    tasks = []\n\n    async def generate_task(prompt):\n        async with semaphore:\n            return await async_engine.completions.create(\n                prompt=prompt,\n                stream=False,\n                max_tokens=1,\n                temperature=1.0,\n                logprobs=True,\n                top_logprobs=5,\n            )\n\n    for prompt in prompts:\n        task = asyncio.create_task(generate_task(prompt))\n        tasks.append(task)\n\n    return await tqdm.asyncio.tqdm.gather(\n        *tasks,\n        desc=f\"Running {subject.ljust(PADDING_LEN)}\",\n        bar_format=\"{desc} {percentage:3.0f}%|{bar}{r_bar}\",\n    )\n\n\nasync def evaluate(  # pylint: disable=too-many-arguments, too-many-locals\n    model: str,\n    device: str,\n    dataset: Path,\n    model_lib: Optional[str],\n    subjects: List[str],\n    semaphore: 
asyncio.Semaphore,\n    log_dir: Optional[Path],  
# pylint: disable=redefined-outer-name\n):\n    \"\"\"Evaluate MMLU for the model.\"\"\"\n    async_engine = AsyncMLCEngine(model, device=device, model_lib=model_lib, mode=\"server\")\n\n    results: Dict[str, Any] = {}\n    for subject in subjects:\n        with open(dataset / \"test\" / f\"{subject}_test.csv\", encoding=\"utf-8\") as csvfile:\n            tests = list(csv.reader(csvfile, delimiter=\",\", quotechar='\"'))\n            assert all(len(test) == 6 for test in tests)\n\n        logs = []\n        num_correct = 0\n        prompts = [\n            PROMPT_TEMPLATE.substitute(Q=test[0], A=test[1], B=test[2], C=test[3], D=test[4])\n            for test in tests\n        ]\n        responses = await send_request(async_engine, prompts, semaphore, subject)\n\n        assert len(responses) == len(tests)\n        for response, test in zip(responses, tests):\n            token_logprobs = {}\n            logprobs = response.choices[0].logprobs.content[0].top_logprobs\n            for logprob in logprobs:\n                if logprob.token not in token_logprobs:\n                    token_logprobs[logprob.token] = logprob.logprob\n\n            abcd_logprobs = {}\n            for choice in [\"A\", \"B\", \"C\", \"D\"]:\n                abcd_logprobs[choice] = token_logprobs[choice] if choice in token_logprobs else -100\n\n            pred = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\"}[int(np.argmax(list(abcd_logprobs.values())))]\n            num_correct += pred == test[5]\n\n            logs.append(\n                {\n                    \"Question\": {\n                        \"Q\": test[0],\n                        \"A\": test[1],\n                        \"B\": test[2],\n                        \"C\": test[3],\n                        \"D\": test[4],\n                    },\n                    \"Answer\": test[5],\n                    \"Response\": {\n                        \"pred\": pred,\n                        \"logprobs\": list(abcd_logprobs.values()),\n    
                },\n                }\n            )\n\n        results[subject] = {\n            \"correct\": num_correct,\n            \"total\": len(tests),\n            \"accuracy\": num_correct / len(tests),\n        }\n\n        if log_dir:\n            with open(log_dir / \"subjects\" / f\"{subject}.json\", \"w\", encoding=\"utf-8\") as f:\n                json.dump(logs, f, indent=2)\n\n    total_correct, total_tests = 0, 0\n    for subject, v in results.items():\n        num_correct, num_tests, accuracy = v[\"correct\"], v[\"total\"], v[\"accuracy\"]\n        print(f\"{subject}: {num_correct} / {num_tests} = {accuracy * 100:.2f}%\")\n        total_correct += num_correct\n        total_tests += num_tests\n\n    total_accuracy = total_correct / total_tests\n    results[\"total\"] = {\n        \"correct\": total_correct,\n        \"total\": total_tests,\n        \"accuracy\": total_accuracy,\n    }\n    print(f\"Total accuracy: {total_correct} / {total_tests} = {total_accuracy * 100:.2f}%\")\n\n    if log_dir:\n        results = {\n            \"config\": {\n                \"model\": model,\n                \"device\": device,\n                \"model_lib\": model_lib,\n                \"subjects\": subjects,\n            },\n            \"results\": results,\n        }\n        with open(log_dir / \"summary.json\", \"w\", encoding=\"utf-8\") as f:\n            json.dump(results, f, indent=2)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    start_time = datetime.now()\n    log_dir: Optional[Path] = None\n    if args.log_dir is not None:\n        time_dir = start_time.strftime(\"%Y-%m-%d_%H-%M-%S\")\n        log_dir = args.log_dir / time_dir\n        (log_dir / \"subjects\").mkdir(parents=True, exist_ok=True)\n    asyncio.run(\n        evaluate(\n            model=args.model,\n            device=args.device,\n            dataset=args.dataset,\n            model_lib=args.model_lib,\n            subjects=args.subject,\n            
semaphore=asyncio.Semaphore(args.batch_size),\n            log_dir=log_dir,\n        )\n    )\n    end_time = datetime.now()\n    print(f\"Time used: {end_time - start_time}\")\n"
  },
  {
    "path": "python/mlc_llm/bench/request_processor.py",
    "content": "\"\"\"MLC LLM Bench Request\"\"\"\n\nimport argparse\nimport asyncio\nimport concurrent.futures\nimport copy\nimport os\nimport random\nimport time\nfrom typing import Any, Callable, Dict, List, Optional\n\nimport numpy as np\nimport requests\nfrom tqdm import tqdm\nfrom transformers import AutoTokenizer  # pylint: disable=import-error\n\nfrom mlc_llm.bench.api_endpoint import APIEndPoint\nfrom mlc_llm.bench.dataset import Dataset\nfrom mlc_llm.bench.request_record import GroupedRequestRecord, RequestRecord\nfrom mlc_llm.protocol.openai_api_protocol import (\n    ChatCompletionMessage,\n    ChatCompletionRequest,\n    DebugConfig,\n)\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\nclass RequestProcessor:  # pylint: disable=too-few-public-methods\n    \"\"\"The request processor base class.\n    Each processor can take a list of RequestRecord, applying the process,\n    and returning the processed RequestRecord in the end.\n    \"\"\"\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        raise NotImplementedError()\n\n\nclass LogMessage(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that prints the logger message.\"\"\"\n\n    def __init__(self, message: str) -> None:\n        self.message = message\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        logger.info(self.message)\n        return request_records\n\n\nclass SampleRequests(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that samples requests out from the given request list.\"\"\"\n\n    def __init__(self, num_requests: int, take_first_x_requests: bool = False) -> None:\n        self.num_requests = num_requests\n        # If `take_first_x_requests` is True, the first `num_requests` requests\n        # are returned and sampling will not happen.\n        self.take_first_x_requests = 
take_first_x_requests\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        assert len(request_records) > 0, \"Empty input request record.\"\n\n        # We expect the input request records to be all grouped or all plain.\n        if isinstance(request_records[0], GroupedRequestRecord):\n            assert all(isinstance(record, GroupedRequestRecord) for record in request_records)\n            return self._sample_from_grouped_request_records(request_records)\n\n        assert all(not isinstance(record, GroupedRequestRecord) for record in request_records)\n        return self._sample_from_plain_request_records(request_records)\n\n    def _sample_from_plain_request_records(\n        self, request_records: List[RequestRecord]\n    ) -> List[RequestRecord]:\n        samples: List[RequestRecord] = []\n        if self.take_first_x_requests:\n            if len(request_records) < self.num_requests:\n                raise ValueError(\n                    f\"Insufficient requests. 
Requiring {self.num_requests} requests \"\n                    f\"but only {len(request_records)} are available.\"\n                )\n            samples = copy.deepcopy(list(request_records[: self.num_requests]))\n        else:\n            while len(samples) < self.num_requests:\n                # Create a new list so that the in-place shuffle does not mutate the input list.\n                records = list(request_records)\n                random.shuffle(records)\n                samples += copy.deepcopy(records)\n            samples = samples[: self.num_requests]\n        for i, record in enumerate(samples):\n            record.request_id = i\n        return samples\n\n    def _sample_from_grouped_request_records(\n        self, grouped_request_records: List[GroupedRequestRecord]\n    ) -> List[RequestRecord]:\n        num_total_available_requests = sum(\n            len(record.records) for record in grouped_request_records\n        )\n        if self.num_requests > num_total_available_requests:\n            raise ValueError(\n                \"Due to the existence of shared common prefixes, we do not allow \"\n                \"benchmarking with more requests than are available in the dataset. 
\"\n                f\"The required number of requests {self.num_requests} exceeds the \"\n                f\"number of total available requests {num_total_available_requests}.\"\n            )\n\n        # Create a new list so that the in-place shuffle does not mutate the input list.\n        records = list(grouped_request_records)\n        if not self.take_first_x_requests:\n            random.shuffle(records)\n        remaining = self.num_requests\n        samples: List[RequestRecord] = []\n        # Iterate over the (possibly shuffled) copy, not the original input order.\n        for grouped_request_record in records:\n            num_used_requests = min(len(grouped_request_record.records), remaining)\n            samples += grouped_request_record.records[:num_used_requests]\n            remaining -= num_used_requests\n            if remaining == 0:\n                break\n        for i, record in enumerate(samples):\n            record.request_id = i\n        return samples\n\n\nclass AttachModelName(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that attaches model name to requests.\"\"\"\n\n    def __init__(self, model: str) -> None:\n        self.model = model\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        for request_record in request_records:\n            request_record.chat_cmpl.model = self.model\n        return request_records\n\n\nclass AttachRequestRateTimestamp(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that applies timestamps to the requests.\"\"\"\n\n    def __init__(self, request_rate: np.float32) -> None:\n        self.request_rate = request_rate\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        timestamp = 0.0\n        for request_record in request_records:\n            assert request_record.timestamp is None, \"The request record already has a timestamp\"\n            request_record.timestamp = timestamp\n            timestamp += 
float(np.random.exponential(1.0 / self.request_rate))\n        return request_records\n\n\nclass AttachExecutionFeature(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that attaches execution features to all requests.\"\"\"\n\n    def __init__(self, exec_feature: Dict[str, Any]) -> None:\n        self.exec_feature = exec_feature\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        for request_record in request_records:\n            assert request_record.metrics is not None\n            request_record.metrics.exec_feature = self.exec_feature\n        return request_records\n\n\nclass AttachStreamFlag(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that attaches the stream flag to the requests.\"\"\"\n\n    def __init__(self, stream: Optional[bool]) -> None:\n        self.stream = stream\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        if self.stream is None:\n            return request_records\n        for request_record in request_records:\n            request_record.chat_cmpl.stream = self.stream\n        return request_records\n\n\nclass AttachSamplingOptions(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that attaches sampling options (temperature, top_p, ignore_eos) to the requests.\"\"\"\n\n    def __init__(self, temperature: float, top_p: float, ignore_eos: bool) -> None:\n        self.temperature = temperature\n        self.top_p = top_p\n        self.ignore_eos = ignore_eos\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        for request_record in request_records:\n            request_record.chat_cmpl.temperature = self.temperature\n            request_record.chat_cmpl.top_p = self.top_p\n            request_record.chat_cmpl.frequency_penalty = 0.0\n            request_record.chat_cmpl.presence_penalty = 0.0\n            
request_record.chat_cmpl.tool_choice = \"none\"\n            if self.ignore_eos:\n                request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True)\n        return request_records\n\n\nclass ScaleTimestamp(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"Scale the timestamp of requests by the given scale factor.\"\"\"\n\n    def __init__(self, timestamp_scale: float):\n        self.timestamp_scale = timestamp_scale\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        for request_record in request_records:\n            if request_record.timestamp is None:\n                raise ValueError(\n                    f\"The timestamp of request {request_record} has not been initialized.\"\n                )\n            request_record.timestamp *= self.timestamp_scale\n        return request_records\n\n\nclass MetricAnalyzer(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that analyzes the raw benchmark results and computes more detailed metrics.\"\"\"\n\n    def __init__(self, tokenizer: AutoTokenizer) -> None:\n        self.tokenizer = tokenizer\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        updated_records = []\n        for request_record in request_records:\n            metrics = request_record.metrics\n            if not metrics.success:\n                assert request_record.error_msg is not None\n                continue\n\n            metrics.output_tokens = len(\n                self.tokenizer.encode(request_record.output_str, add_special_tokens=False)\n            )\n            first_chunk_output_tokens = len(\n                self.tokenizer.encode(\n                    request_record.first_chunk_output_str, add_special_tokens=False\n                )\n            )\n            if metrics.output_tokens <= first_chunk_output_tokens:\n                metrics.success = False\n                
request_record.error_msg = (\n                    f\"Total output token count ({metrics.output_tokens}) does not exceed the \"\n                    f\"first chunk output token count ({first_chunk_output_tokens}). \"\n                    f'Output text \"{request_record.output_str}\", '\n                    f'first chunk output text \"{request_record.first_chunk_output_str}\"'\n                )\n                continue\n            assert metrics.input_tokens > 0, \"Invalid prompt tokens\"\n            metrics.inter_token_latency_s = metrics.end_to_end_latency_s / metrics.output_tokens\n            if metrics.time_to_first_token_s is None:\n                metrics.time_to_first_token_s = 0\n            metrics.time_per_output_token_s = (\n                metrics.end_to_end_latency_s - metrics.time_to_first_token_s\n            ) / (metrics.output_tokens - first_chunk_output_tokens)\n            updated_records.append(request_record)\n        return updated_records\n\n\nclass WarmupAndRun(RequestProcessor):  # pylint: disable=too-few-public-methods,line-too-long\n    \"\"\"The processor that runs warmup first and then runs the benchmark with the given pipeline.\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        num_warmup_requests: int,\n        num_benchmark_requests: int,\n        pipeline: RequestProcessor,\n        cuda_profile_url: Optional[str],\n        fake_warmup: bool = False,\n    ) -> None:\n        self.num_warmup_requests = num_warmup_requests\n        self.num_benchmark_requests = num_benchmark_requests\n        self.pipeline = pipeline\n        self.cuda_profile_url = cuda_profile_url\n        self.fake_warmup = fake_warmup\n\n    def generate_fake_warmup_requests(  # pylint: disable=missing-function-docstring\n        self, num_warmup_requests: int, example_request: RequestRecord\n    ) -> List[RequestRecord]:\n        records = []\n        for _ in range(num_warmup_requests):\n            record = copy.deepcopy(example_request)\n            record.chat_cmpl = ChatCompletionRequest(\n 
               messages=[\n                    {\n                        \"role\": \"user\",\n                        \"content\": \"Please output arbitrary coherent sentences. Do not output eos token.\",  # pylint: disable=line-too-long\n                    }\n                ],\n                model=\"\",\n                max_tokens=128,\n            )\n            records.append(record)\n        return records\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        # Warmup\n        if self.fake_warmup:\n            assert len(request_records) == self.num_benchmark_requests\n            benchmark_requests = request_records\n            example_request = benchmark_requests[0]\n            warmup_requests = self.generate_fake_warmup_requests(\n                self.num_warmup_requests, example_request=example_request\n            )\n        else:\n            assert len(request_records) == self.num_warmup_requests + self.num_benchmark_requests\n            # Slice by the benchmark count: `[: -n]` would yield an empty list when n == 0.\n            benchmark_requests = request_records[: self.num_benchmark_requests]\n            warmup_requests = request_records[self.num_benchmark_requests :]\n        for request_record in warmup_requests:\n            request_record.timestamp = 0 if request_record.timestamp is not None else None\n        warmup_requests = self._process_warmup_requests(warmup_requests)\n        # Skip the warmup run when there are no warmup requests.\n        if len(warmup_requests) > 0:\n            logger.info(\"Warmup with %d request(s)...\", len(warmup_requests))\n            self.pipeline(warmup_requests)\n\n        # Then run the benchmark.\n        if self.cuda_profile_url is not None:\n            cuda_profiler_start_url = self.cuda_profile_url + \"/debug/cuda_profiler_start\"\n            cuda_profiler_start_response = requests.post(cuda_profiler_start_url, timeout=60)\n            assert cuda_profiler_start_response.status_code == 200\n        logger.info(\"Warmup finished. 
Start benchmarking...\")\n        updated_request_records = self.pipeline(benchmark_requests)\n        if self.cuda_profile_url is not None:\n            cuda_profiler_stop_url = self.cuda_profile_url + \"/debug/cuda_profiler_stop\"\n            cuda_profiler_stop_response = requests.post(cuda_profiler_stop_url, timeout=60)\n            assert cuda_profiler_stop_response.status_code == 200\n\n        return updated_request_records\n\n    def _process_warmup_requests(self, warmup_requests: List[RequestRecord]) -> List[RequestRecord]:\n        if len(warmup_requests) == 0:\n            return warmup_requests\n        # NOTE: To warm up the server for as many different batch sizes as possible,\n        # we use 128 output tokens for the first request and one more token\n        # for every follow-up request.\n        # Use a high temperature and top-p to avoid early stopping as much as possible.\n        warmup_requests[0].chat_cmpl.max_tokens = 128\n        for i in range(1, len(warmup_requests)):\n            warmup_requests[i].chat_cmpl.max_tokens = (\n                warmup_requests[i - 1].chat_cmpl.max_tokens + 1\n            )\n            warmup_requests[i].chat_cmpl.temperature = 2.0\n            warmup_requests[i].chat_cmpl.top_p = 1.0\n        return warmup_requests\n\n\nclass SequentialProcessor(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The processor that applies a list of processors in order.\"\"\"\n\n    processors: List[RequestProcessor]\n\n    def __init__(self, *processors: RequestProcessor) -> None:\n        self.processors = list(processors)\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        for processor in self.processors:\n            request_records = processor(request_records)\n        return request_records\n\n\nclass Executor(RequestProcessor):  # pylint: disable=too-few-public-methods\n    \"\"\"The executor base class, denoting the kind of 
benchmark mode.\"\"\"\n\n    def __init__(\n        self,\n        f_create_api_endpoint: Callable[[], APIEndPoint],\n        num_processes: int,\n        disable_tqdm: bool,\n    ) -> None:\n        self.f_create_api_endpoint = f_create_api_endpoint\n        self.disable_tqdm = disable_tqdm\n        self.num_processes = num_processes\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        raise NotImplementedError()\n\n\nclass FixedConcurrentRequestExecutor(Executor):  # pylint: disable=too-few-public-methods\n    \"\"\"The benchmark executor that keeps a fixed number of concurrent requests in flight.\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        f_create_api_endpoint: Callable[[], APIEndPoint],\n        num_processes: Optional[int],\n        disable_tqdm: bool,\n        num_concurrent_requests: int,\n        multi_round: bool,\n    ) -> None:\n        if num_processes is None:\n            # Aim for at most 32 concurrent requests per process (with at most 10 processes)\n            # to limit the asyncio load per process.\n            num_processes = min((num_concurrent_requests + 31) // 32, 10)\n        super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)\n        self.num_concurrent_requests = num_concurrent_requests\n        self.multi_round = multi_round\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        partitions: List[List[RequestRecord]] = [\n            request_records[slice(i, len(request_records), self.num_processes)]\n            for i in range(self.num_processes)\n        ]\n        # The \"tokenizers\" package reports warnings under multiprocessing.\n        # We disable \"TOKENIZERS_PARALLELISM\" to suppress the warnings.\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n        pbar = None if self.disable_tqdm else tqdm(total=len(request_records))\n        with 
concurrent.futures.ProcessPoolExecutor(max_workers=self.num_processes) as pool:\n            futures = [\n                pool.submit(\n                    FixedConcurrentRequestExecutor._process_task,\n                    self.f_create_api_endpoint,\n                    partition,\n                    self.num_concurrent_requests // self.num_processes\n                    + int(i < self.num_concurrent_requests % self.num_processes),\n                    self.multi_round,\n                )\n                for i, partition in enumerate(partitions)\n            ]\n            results: List[RequestRecord] = []\n            for future in concurrent.futures.as_completed(futures):\n                # Futures complete in arbitrary order, so advance the progress bar by the\n                # size of the finished partition instead of indexing into `partitions`.\n                partition_results = future.result()\n                results.extend(partition_results)\n                if pbar is not None:\n                    pbar.update(len(partition_results))\n\n        return results\n\n    @staticmethod\n    def _process_task(\n        f_create_api_endpoint: Callable[[], APIEndPoint],\n        request_records: List[RequestRecord],\n        num_concurrent_requests: int,\n        multi_round: bool,\n    ) -> List[RequestRecord]:\n        if len(request_records) == 0:\n            return []\n        chat_history: List[List[ChatCompletionMessage]] = [\n            [] for _ in range(num_concurrent_requests)\n        ]\n\n        async def process_task_impl(\n            f_create_api_endpoint: Callable[[], APIEndPoint],\n            request_records: List[RequestRecord],\n            num_concurrent_requests: int,\n            multi_round: bool,\n        ) -> List[RequestRecord]:\n            api_endpoint = f_create_api_endpoint()\n            updated_request_records: List[RequestRecord] = [None for _ in request_records]\n            async with api_endpoint:\n                num_sent_request = 0\n\n                async def _task(i: int) -> None:\n                    nonlocal num_sent_request\n                    while True:\n                        if num_sent_request == len(request_records):\n        
                    break\n                        idx = num_sent_request\n                        num_sent_request += 1\n                        request = request_records[idx]\n\n                        if multi_round:\n                            request.chat_cmpl.messages = (\n                                chat_history[i] + request.chat_cmpl.messages\n                            )\n\n                        updated_request_records[idx] = await api_endpoint(request)\n\n                        if multi_round:\n                            chat_history[i] = updated_request_records[idx].chat_cmpl.messages + [\n                                ChatCompletionMessage(\n                                    content=updated_request_records[idx].output_str,\n                                    role=\"assistant\",\n                                )\n                            ]\n\n                tasks = [asyncio.create_task(_task(i)) for i in range(num_concurrent_requests)]\n                await asyncio.gather(*tasks)\n\n            return updated_request_records\n\n        return asyncio.run(\n            process_task_impl(\n                f_create_api_endpoint,\n                request_records,\n                num_concurrent_requests,\n                multi_round,\n            )\n        )\n\n\nclass FixTimestampExecutor(Executor):  # pylint: disable=too-few-public-methods\n    \"\"\"The benchmark executor that sends each request at its fixed timestamp.\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        f_create_api_endpoint: Callable[[], APIEndPoint],\n        num_processes: Optional[int],\n        disable_tqdm: bool,\n        max_schedule_gap: float,\n        num_requests: int,\n    ) -> None:\n        if num_processes is None:\n            # Aim for at most 32 requests per process (with at most 10 processes)\n            # to limit the asyncio load per process.\n            num_processes = min((num_requests + 31) // 32, 10)\n        
super().__init__(f_create_api_endpoint, num_processes, disable_tqdm)\n        self.max_schedule_gap = max_schedule_gap\n        self.num_requests = num_requests\n\n    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:\n        assert len(request_records) > 0\n        assert all(request_record.timestamp is not None for request_record in request_records)\n        # Sort the request records in ascending timestamp order before partitioning.\n        request_records.sort(key=lambda request_record: request_record.timestamp)\n        base_timestamp = request_records[0].timestamp\n        partitions: List[List[RequestRecord]] = [\n            request_records[slice(i, len(request_records), self.num_processes)]\n            for i in range(self.num_processes)\n        ]\n        base_sys_time = time.time()\n        # The \"tokenizers\" package reports warnings under multiprocessing.\n        # We disable \"TOKENIZERS_PARALLELISM\" to suppress the warnings.\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n\n        pbar = None if self.disable_tqdm else tqdm(total=len(request_records))\n        with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_processes) as pool:\n            futures = [\n                pool.submit(\n                    FixTimestampExecutor._process_task,\n                    self.f_create_api_endpoint,\n                    partition,\n                    base_timestamp,\n                    base_sys_time,\n                    self.max_schedule_gap,\n                )\n                for partition in partitions\n            ]\n            results: List[RequestRecord] = []\n            for future in concurrent.futures.as_completed(futures):\n                # Futures complete in arbitrary order, so advance the progress bar by the\n                # size of the finished partition instead of indexing into `partitions`.\n                partition_results = future.result()\n                results.extend(partition_results)\n                if pbar is not None:\n                    pbar.update(len(partition_results))\n\n        return results\n\n    @staticmethod\n    def _process_task(\n        f_create_api_endpoint: Callable[[], 
APIEndPoint],\n        request_records: List[RequestRecord],\n        base_timestamp: float,\n        base_sys_time: float,\n        max_schedule_gap: float,\n    ) -> List[RequestRecord]:\n        if len(request_records) == 0:\n            return []\n\n        async def process_task_impl(\n            f_create_api_endpoint: Callable[[], APIEndPoint],\n            request_records: List[RequestRecord],\n            base_timestamp: float,\n            base_sys_time: float,\n            max_schedule_gap: float,\n        ) -> List[RequestRecord]:\n            api_endpoint = f_create_api_endpoint()\n            loop = asyncio.get_running_loop()\n            # Compute the offset that converts system time to event-loop time.\n            # We must use the system time `time.time()`, which is consistent across processes.\n            loop_sys_delta_time = loop.time() - time.time()\n            updated_request_records: List[RequestRecord] = []\n            async with api_endpoint:\n\n                async def _task(request_record: RequestRecord) -> None:\n                    updated_request_records.append(await api_endpoint(request_record))\n\n                tasks = []\n                for request_record in request_records:\n                    launch_time = (\n                        (request_record.timestamp - base_timestamp)\n                        + (base_sys_time + max_schedule_gap)\n                        + loop_sys_delta_time\n                    )\n                    loop.call_at(\n                        launch_time,\n                        lambda record: tasks.append(asyncio.create_task(_task(record))),\n                        request_record,\n                    )\n                    # Sleep until shortly before this launch time so other scheduled tasks can run.\n                    await asyncio.sleep(max(launch_time - loop.time() - max_schedule_gap, 0))\n\n                # Sleep until all the tasks are launched.\n                await asyncio.sleep(launch_time - loop.time() + 
max_schedule_gap)\n                # All tasks should have been launched by now; wait for them to finish.\n                assert len(tasks) == len(request_records)\n                await asyncio.gather(*tasks)\n\n            assert len(updated_request_records) == len(request_records)\n            return updated_request_records\n\n        return asyncio.run(\n            process_task_impl(\n                f_create_api_endpoint,\n                request_records,\n                base_timestamp,\n                base_sys_time,\n                max_schedule_gap,\n            )\n        )\n\n\ndef create_pipelines(  # pylint: disable=too-many-branches\n    args: argparse.Namespace,\n    f_create_api_endpoint: Callable[[], APIEndPoint],\n    dataset: Dataset,\n) -> List[RequestProcessor]:\n    \"\"\"Create request processing pipelines according to the specified args.\"\"\"\n    cuda_profile_url = f\"http://{args.host}:{args.port}\" if args.cuda_profile else None\n    pipelines: List[RequestProcessor] = []\n    if args.num_concurrent_requests is not None:\n        if args.request_rate is not None:\n            raise ValueError(\n                'Both \"num_concurrent_requests\" and \"request_rate\" are specified. 
'\n                \"Please specify only one of them.\"\n            )\n        if args.replay_timestamp_scale is not None:\n            raise ValueError(\n                \"Dataset replay is unsupported when fixing the number of concurrent requests.\"\n            )\n        for num_concurrent_requests in args.num_concurrent_requests:\n            num_warmup_requests = (\n                args.num_warmup_requests\n                if args.num_warmup_requests is not None\n                else num_concurrent_requests\n            )\n            # With fake warmup, warmup requests are generated, not sampled from the dataset.\n            num_samples = (\n                args.num_requests\n                if dataset.require_fake_warmup\n                else args.num_requests + num_warmup_requests\n            )\n            pipelines.append(\n                SequentialProcessor(\n                    LogMessage(f\"Fixing the number of concurrent requests: {num_concurrent_requests}\"),\n                    SampleRequests(num_samples),\n                    AttachModelName(args.tokenizer),\n                    AttachStreamFlag(args.stream),\n                    AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),\n                    AttachExecutionFeature({\"num_concurrent_requests\": num_concurrent_requests}),\n                    WarmupAndRun(\n                        num_warmup_requests=num_warmup_requests,\n                        num_benchmark_requests=args.num_requests,\n                        pipeline=FixedConcurrentRequestExecutor(\n                            f_create_api_endpoint,\n                            args.num_process_workers,\n                            args.disable_tqdm,\n                            num_concurrent_requests,\n                            args.multi_round,\n                        ),\n                        cuda_profile_url=cuda_profile_url,\n                        fake_warmup=dataset.require_fake_warmup,\n                    ),\n                )\n            )\n        return pipelines\n    if args.request_rate is not None:\n        if args.num_warmup_requests is None:\n            raise ValueError(\n                \"Please specify the number of warmup requests via \"\n     
           '\"--num-warmup-requests\" when fixing request rate.'\n            )\n        if args.replay_timestamp_scale is not None:\n            raise ValueError(\"Dataset replay is unsupported when fixing request rates.\")\n        num_total_requests = int(\n            args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus\n        )\n        if dataset.require_fake_warmup:\n            num_samples = num_total_requests\n        else:\n            num_samples = num_total_requests + args.num_warmup_requests\n        return [\n            SequentialProcessor(\n                LogMessage(f\"Fixing request rate: {request_rate}\"),\n                SampleRequests(num_samples),\n                AttachModelName(args.tokenizer),\n                AttachRequestRateTimestamp(\n                    request_rate if not args.per_gpu_workload else request_rate * args.num_gpus\n                ),\n                AttachStreamFlag(args.stream),\n                AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),\n                AttachExecutionFeature({\"request_rate\": float(request_rate)}),\n                WarmupAndRun(\n                    num_warmup_requests=args.num_warmup_requests,\n                    num_benchmark_requests=num_total_requests,\n                    pipeline=FixTimestampExecutor(\n                        f_create_api_endpoint,\n                        args.num_process_workers,\n                        args.disable_tqdm,\n                        args.max_schedule_gap,\n                        num_total_requests,\n                    ),\n                    cuda_profile_url=cuda_profile_url,\n                    fake_warmup=dataset.require_fake_warmup,\n                ),\n            )\n            for request_rate in args.request_rate\n        ]\n\n    # Default: dataset replay mode\n    # The dataset must come with timestamps.\n    if not dataset.timestamp_available:\n        raise ValueError(\n            
\"The dataset does not have timestamps, so dataset replay is unsupported. \"\n            'Please specify one of \"num_concurrent_requests\" '\n            'or \"request_rate\".'\n        )\n    if args.per_gpu_workload:\n        raise ValueError(\"Fixing per-GPU workload is not compatible with dataset replay.\")\n    if args.num_warmup_requests is None:\n        raise ValueError(\n            \"Please specify the number of warmup requests via \"\n            '\"--num-warmup-requests\" for dataset replay.'\n        )\n    timestamp_scale = args.replay_timestamp_scale or 1.0\n    if dataset.require_fake_warmup:\n        num_samples = args.num_requests\n    else:\n        num_samples = args.num_requests + args.num_warmup_requests\n    return [\n        SequentialProcessor(\n            LogMessage(f\"Dataset replay with time scaling of {timestamp_scale}\"),\n            SampleRequests(num_samples, take_first_x_requests=True),\n            AttachModelName(args.tokenizer),\n            ScaleTimestamp(timestamp_scale),\n            AttachStreamFlag(args.stream),\n            AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos),\n            AttachExecutionFeature({\"timestamp_scale\": timestamp_scale}),\n            WarmupAndRun(\n                num_warmup_requests=args.num_warmup_requests,\n                num_benchmark_requests=args.num_requests,\n                pipeline=FixTimestampExecutor(\n                    f_create_api_endpoint,\n                    args.num_process_workers,\n                    args.disable_tqdm,\n                    args.max_schedule_gap,\n                    args.num_requests,\n                ),\n                cuda_profile_url=cuda_profile_url,\n                fake_warmup=dataset.require_fake_warmup,\n            ),\n        )\n    ]\n"
  },
  {
    "path": "python/mlc_llm/bench/request_record.py",
    "content": "\"\"\"MLC LLM Bench Request\"\"\"\n\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport pandas as pd  # pylint: disable=import-error\nfrom pydantic import BaseModel\n\nfrom mlc_llm.protocol.openai_api_protocol import ChatCompletionRequest\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\nclass ServerMetrics(BaseModel):\n    \"\"\"The metrics from the server side.\"\"\"\n\n    input_tokens: int\n    prefill_tokens: int\n    output_tokens: int\n    end_to_end_latency_s: float\n    prefill_tokens_per_s: float\n    inter_token_latency_s: float\n    time_per_output_token_s: float\n    time_to_first_token_s: Optional[float] = None\n\n\nclass Metrics(BaseModel):\n    \"\"\"The metrics collected for a single request.\"\"\"\n\n    success: bool\n    start_time: float\n    finish_time: float\n    end_to_end_latency_s: float\n\n    input_tokens: Optional[int] = None\n    output_tokens: Optional[int] = None\n    inter_token_latency_s: Optional[float] = None\n    time_per_output_token_s: Optional[float] = None\n    time_to_first_token_s: Optional[float] = None\n    server_metrics: Optional[ServerMetrics] = None\n\n    exec_feature: Optional[Dict[str, Any]] = None\n\n\nclass RequestRecord(BaseModel):\n    \"\"\"The record collected from a single LLM inference request.\"\"\"\n\n    request_id: Optional[int] = None\n    chat_cmpl: ChatCompletionRequest\n    output_str: Optional[str] = None\n    first_chunk_output_str: str = \"\"\n    timestamp: Optional[float] = None\n    metrics: Optional[Metrics] = None\n    error_msg: Optional[str] = None\n\n\nclass GroupedRequestRecord(RequestRecord):\n    \"\"\"The data structure for request record groups.\n    For datasets with shared common prefixes, the request records\n    that share the same common prefix are wrapped in a GroupedRequestRecord\n    at the beginning.\n    \"\"\"\n\n    records: List[RequestRecord]\n\n\ndef generate_metrics_summary(\n    request_records: 
List[RequestRecord],\n    num_total_requests: int,\n    num_gpus: int,\n) -> Dict[str, Any]:\n    \"\"\"Computes summary statistics across all metrics collected.\n    Return a dictionary as the report.\n    \"\"\"\n    num_completed_requests = len(request_records)\n    assert num_completed_requests <= num_total_requests\n    request_metrics = [record.metrics for record in request_records]\n    duration = (\n        max(metrics.finish_time for metrics in request_metrics)\n        - min(metrics.start_time for metrics in request_metrics)\n        if num_completed_requests > 0\n        else 1e-5\n    )\n\n    report = _compute_metrics_statistics(request_metrics)\n    report[\"num_gpus\"] = num_gpus\n    report[\"duration\"] = duration\n    report[\"num_total_requests\"] = num_total_requests\n    report[\"num_completed_requests\"] = num_completed_requests\n    report[\"request_throughput\"] = num_completed_requests / duration\n\n    total_input_tokens = sum(metric.input_tokens for metric in request_metrics)\n    total_output_tokens = sum(metric.output_tokens for metric in request_metrics)\n    report[\"total_input_tokens\"] = total_input_tokens\n    report[\"total_output_tokens\"] = total_output_tokens\n    report[\"input_token_throughput\"] = total_input_tokens / duration\n    report[\"input_token_throughput_per_gpu\"] = report[\"input_token_throughput\"] / num_gpus\n    report[\"output_token_throughput\"] = total_output_tokens / duration\n    report[\"output_token_throughput_per_gpu\"] = report[\"output_token_throughput\"] / num_gpus\n\n    # Generate the server metrics statistics\n    server_metrics = [metric.server_metrics for metric in request_metrics if metric.server_metrics]\n    server_report = _compute_metrics_statistics(server_metrics)\n    if server_report is not None and len(server_report) > 0:\n        report[\"server_metrics\"] = server_report\n\n    report = {\n        \"exec_feature\": (\n            request_records[0].metrics.exec_feature if 
num_completed_requests > 0 else None\n        ),\n        **report,\n    }\n    return report\n\n\ndef _compute_metrics_statistics(\n    metrics: List[Union[Metrics, ServerMetrics]],\n) -> Dict[str, Any]:\n    \"\"\"\n    Compute the statistics of the metrics.\n\n    Parameters\n    ----------\n    metrics : List[Union[Metrics, ServerMetrics]]\n        The list of metrics to get the statistics.\n\n    Returns\n    -------\n    report : Dict\n        The statistics of the metrics.\n    \"\"\"\n    if not metrics:\n        return {}\n\n    report: Dict = {}\n    df = pd.DataFrame([metric.model_dump() for metric in metrics])\n    for key, _ in metrics[0].model_fields.items():\n        if key in [\n            \"success\",\n            \"start_time\",\n            \"finish_time\",\n            \"server_metrics\",\n            \"exec_feature\",\n        ]:\n            continue\n        if key in df.columns:\n            series = df[key].dropna()\n            report[key] = {\n                \"quantiles\": {\n                    f\"p{int(q * 100)}\": v\n                    for q, v in series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).items()\n                },\n                \"mean\": series.mean(),\n                \"min\": series.min(),\n                \"max\": series.max(),\n                \"stddev\": series.std(),\n            }\n    return report\n\n\ndef convert_reports_to_df(reports: List[Dict[str, Any]]) -> pd.DataFrame:\n    \"\"\"Convert benchmark reports to pandas DataFrame.\"\"\"\n\n    def _flatten_dict(d: Dict[str, Any], parent_key: str = \"\") -> Dict[str, Any]:\n        items: List[Tuple[str, Any]] = []\n        for key, value in d.items():\n            new_key = f\"{parent_key}.{key}\" if parent_key != \"\" else key\n            if isinstance(value, dict):\n                items.extend(_flatten_dict(value, new_key).items())\n            else:\n                items.append((new_key, value))\n        return dict(items)\n\n    return 
pd.DataFrame([_flatten_dict(report) for report in reports])\n\n\ndef pretty_print_report(report: Dict[str, Any]) -> None:  # pylint: disable=too-many-statements\n    \"\"\"Pretty print the metrics report.\"\"\"\n\n    def _print(report: Dict[str, Any], server_metrics: bool):  # pylint: disable=too-many-statements\n        # pylint: disable=line-too-long\n        # fmt: off\n        title = \"Benchmark Result\"\n        if server_metrics:\n            title += \" (server side)\"\n        print(f\" {title} \".center(50, \"=\"))\n        if not server_metrics:\n            print(f\"{'Total requests:':<40} {report['num_total_requests']:<10}\")\n            print(f\"{'Completed requests:':<40} {report['num_completed_requests']:<10}\")\n            print(f\"{'Duration (s):':<40} {report['duration']:<10.2f}\")\n            print(f\"{'Num GPUs:':<40} {report['num_gpus']:<10}\")\n            print(f\"{'Total input tokens:':<40} {report['total_input_tokens']:<10}\")\n            print(f\"{'Total output tokens:':<40} {report['total_output_tokens']:<10}\")\n            print(f\"{'Request throughput (req/s):':<40} {report['request_throughput']:<10.2f}\")\n            print(f\"{'Input token throughput (tok/s):':<40} {report['input_token_throughput']:<10.2f}\")\n            print(f\"{'Input token throughput per GPU (tok/s):':<40} {report['input_token_throughput_per_gpu']:<10.2f}\")\n            print(f\"{'Output token throughput (tok/s):':<40} {report['output_token_throughput']:<10.2f}\")\n            print(f\"{'Output token throughput per GPU (tok/s):':<40} {report['output_token_throughput_per_gpu']:<10.2f}\")\n\n        if report[\"num_completed_requests\"] == 0:\n            return\n        ttft = report[\"time_to_first_token_s\"]\n        print(\" Time to First Token (TTFT, ms) \".center(50, \"-\"))\n        print(f\"{'Mean:':<40} {ttft['mean'] * 1000:<10.2f}\")\n        print(f\"{'Stddev:':<40} {ttft['stddev'] * 1000:<10.2f}\")\n        print(f\"{'P25:':<40} 
{ttft['quantiles']['p25'] * 1000:<10.2f}\")\n        print(f\"{'P50:':<40} {ttft['quantiles']['p50'] * 1000:<10.2f}\")\n        print(f\"{'P75:':<40} {ttft['quantiles']['p75'] * 1000:<10.2f}\")\n        print(f\"{'P90:':<40} {ttft['quantiles']['p90'] * 1000:<10.2f}\")\n        print(f\"{'P95:':<40} {ttft['quantiles']['p95'] * 1000:<10.2f}\")\n        print(f\"{'P99:':<40} {ttft['quantiles']['p99'] * 1000:<10.2f}\")\n        print(f\"{'Min:':<40} {ttft['min'] * 1000:<10.2f}\")\n        print(f\"{'Max:':<40} {ttft['max'] * 1000:<10.2f}\")\n\n        tpot = report[\"time_per_output_token_s\"]\n        print(\" Time per Output Token (TPOT, ms) \".center(50, \"-\"))\n        print(f\"{'Mean:':<40} {tpot['mean'] * 1000:<10.2f}\")\n        print(f\"{'Stddev:':<40} {tpot['stddev'] * 1000:<10.2f}\")\n        print(f\"{'P25:':<40} {tpot['quantiles']['p25'] * 1000:<10.2f}\")\n        print(f\"{'P50:':<40} {tpot['quantiles']['p50'] * 1000:<10.2f}\")\n        print(f\"{'P75:':<40} {tpot['quantiles']['p75'] * 1000:<10.2f}\")\n        print(f\"{'P90:':<40} {tpot['quantiles']['p90'] * 1000:<10.2f}\")\n        print(f\"{'P95:':<40} {tpot['quantiles']['p95'] * 1000:<10.2f}\")\n        print(f\"{'P99:':<40} {tpot['quantiles']['p99'] * 1000:<10.2f}\")\n        print(f\"{'Min:':<40} {tpot['min'] * 1000:<10.2f}\")\n        print(f\"{'Max:':<40} {tpot['max'] * 1000:<10.2f}\")\n\n        itl = report[\"inter_token_latency_s\"]\n        print(\" Inter-Token Latency (ms) \".center(50, \"-\"))\n        print(f\"{'Mean:':<40} {itl['mean'] * 1000:<10.2f}\")\n        print(f\"{'Stddev:':<40} {itl['stddev'] * 1000:<10.2f}\")\n        print(f\"{'P25:':<40} {itl['quantiles']['p25'] * 1000:<10.2f}\")\n        print(f\"{'P50:':<40} {itl['quantiles']['p50'] * 1000:<10.2f}\")\n        print(f\"{'P75:':<40} {itl['quantiles']['p75'] * 1000:<10.2f}\")\n        print(f\"{'P90:':<40} {itl['quantiles']['p90'] * 1000:<10.2f}\")\n        print(f\"{'P95:':<40} {itl['quantiles']['p95'] * 1000:<10.2f}\")\n       
 print(f\"{'P99:':<40} {itl['quantiles']['p99'] * 1000:<10.2f}\")\n        print(f\"{'Min:':<40} {itl['min'] * 1000:<10.2f}\")\n        print(f\"{'Max:':<40} {itl['max'] * 1000:<10.2f}\")\n\n        e2e_latency = report[\"end_to_end_latency_s\"]\n        print(\" End-to-End Latency (ms) \".center(50, \"-\"))\n        print(f\"{'Mean:':<40} {e2e_latency['mean'] * 1000:<10.2f}\")\n        print(f\"{'Stddev:':<40} {e2e_latency['stddev'] * 1000:<10.2f}\")\n        print(f\"{'P25:':<40} {e2e_latency['quantiles']['p25'] * 1000:<10.2f}\")\n        print(f\"{'P50:':<40} {e2e_latency['quantiles']['p50'] * 1000:<10.2f}\")\n        print(f\"{'P75:':<40} {e2e_latency['quantiles']['p75'] * 1000:<10.2f}\")\n        print(f\"{'P90:':<40} {e2e_latency['quantiles']['p90'] * 1000:<10.2f}\")\n        print(f\"{'P95:':<40} {e2e_latency['quantiles']['p95'] * 1000:<10.2f}\")\n        print(f\"{'P99:':<40} {e2e_latency['quantiles']['p99'] * 1000:<10.2f}\")\n        print(f\"{'Min:':<40} {e2e_latency['min'] * 1000:<10.2f}\")\n        print(f\"{'Max:':<40} {e2e_latency['max'] * 1000:<10.2f}\")\n\n        input_tokens = report[\"input_tokens\"]\n        print(\" Input Tokens \".center(50, \"-\"))\n        print(f\"{'Mean:':<40} {input_tokens['mean']:<1}\")\n        print(f\"{'Stddev:':<40} {input_tokens['stddev']:<1}\")\n        print(f\"{'P25:':<40} {input_tokens['quantiles']['p25']:<1}\")\n        print(f\"{'P50:':<40} {input_tokens['quantiles']['p50']:<1}\")\n        print(f\"{'P95:':<40} {input_tokens['quantiles']['p95']:<1}\")\n        print(f\"{'Min:':<40} {input_tokens['min']:<1}\")\n        print(f\"{'Max:':<40} {input_tokens['max']:<1}\")\n\n        output_tokens = report[\"output_tokens\"]\n        print(\" Output Tokens \".center(50, \"-\"))\n        print(f\"{'Mean:':<40} {output_tokens['mean']:<1}\")\n        print(f\"{'Stddev:':<40} {output_tokens['stddev']:<1}\")\n        print(f\"{'P25:':<40} {output_tokens['quantiles']['p25']:<1}\")\n        print(f\"{'P50:':<40} 
{output_tokens['quantiles']['p50']:<1}\")\n        print(f\"{'P95:':<40} {output_tokens['quantiles']['p95']:<1}\")\n        print(f\"{'Min:':<40} {output_tokens['min']:<1}\")\n        print(f\"{'Max:':<40} {output_tokens['max']:<1}\")\n\n        print(\"=\" * 50)\n\n    # fmt: on\n    # pylint: enable=line-too-long\n    _print(report, server_metrics=False)\n    if \"server_metrics\" in report:\n        _print(report[\"server_metrics\"], server_metrics=True)\n"
  },
  {
    "path": "python/mlc_llm/cli/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/cli/calibrate.py",
    "content": "\"\"\"Command line entrypoint of calibration.\"\"\"\n\nfrom mlc_llm.interface.calibrate import calibrate\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.support.argparse import ArgumentParser\n\nfrom .serve import EngineConfigOverride\n\n\ndef main(argv):\n    \"\"\"Main entrypoint for calibration.\"\"\"\n    parser = ArgumentParser(\"MLC LLM Calibration CLI\")\n    parser.add_argument(\n        \"model\",\n        type=str,\n        help=HELP[\"model\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"device_deploy\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--model-lib\",\n        type=str,\n        default=None,\n        help=HELP[\"model_lib\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=str,\n        required=True,\n        help=HELP[\"output_calibration\"] + \" (required)\",\n    )\n    # Download dataset from\n    # https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json\n    parser.add_argument(\n        \"--dataset\",\n        type=str,\n        required=True,\n        help=HELP[\"calibration_dataset\"] + \" (required)\",\n    )\n\n    parser.add_argument(\n        \"--num-calibration-samples\",\n        type=int,\n        default=16,\n        help=HELP[\"num_calibration_samples\"] + ' (default: \"%(default)s\")',\n    )\n\n    parser.add_argument(\n        \"--seed\",\n        type=int,\n        default=0,\n        help=HELP[\"seed_calibrate\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--overrides\",\n        type=EngineConfigOverride.from_str,\n        default=\"\",\n        help=HELP[\"overrides_serve\"],\n    )\n\n    parsed = parser.parse_args(argv)\n    calibrate(\n        model=parsed.model,\n        
device=parsed.device,\n        model_lib=parsed.model_lib,\n        output=parsed.output,\n        dataset=parsed.dataset,\n        num_calibration_samples=parsed.num_calibration_samples,\n        max_num_sequence=parsed.overrides.max_num_sequence,\n        max_total_sequence_length=parsed.overrides.max_total_seq_length,\n        prefill_chunk_size=parsed.overrides.prefill_chunk_size,\n        max_history_size=parsed.overrides.max_history_size,\n        gpu_memory_utilization=parsed.overrides.gpu_memory_utilization,\n        seed=parsed.seed,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/chat.py",
    "content": "\"\"\"Command line entrypoint of chat.\"\"\"\n\nfrom mlc_llm.interface.chat import ModelConfigOverride, chat\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.support.argparse import ArgumentParser\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and call `mlc_llm.interface.chat`.\"\"\"\n    parser = ArgumentParser(\"MLC LLM Chat CLI\")\n\n    parser.add_argument(\n        \"model\",\n        type=str,\n        help=HELP[\"model\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"device_deploy\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--model-lib\",\n        type=str,\n        default=None,\n        help=HELP[\"model_lib\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--overrides\",\n        type=ModelConfigOverride.from_str,\n        default=\"\",\n        help=HELP[\"modelconfig_overrides\"] + ' (default: \"%(default)s\")',\n    )\n    parsed = parser.parse_args(argv)\n    chat(\n        model=parsed.model,\n        device=parsed.device,\n        model_lib=parsed.model_lib,\n        overrides=parsed.overrides,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/check_device.py",
    "content": "\"\"\"Check if a device exists.\"\"\"\n\nimport os\nimport sys\n\nfrom tvm.runtime import Device\nfrom tvm.runtime import device as as_device\n\n\ndef _check_device(device: Device) -> bool:\n    try:\n        return bool(device.exist)\n    except:  # pylint: disable=bare-except\n        return False\n\n\ndef main():\n    \"\"\"Entrypoint for device check.\"\"\"\n    device_str = sys.argv[1]\n    device_ids = []\n    i = 0\n    while True:\n        if _check_device(as_device(device_str, i)):\n            device_ids.append(i)\n            i += 1\n            if device_str in [\"cpu\", \"llvm\"] and i > os.cpu_count() / 2:\n                break\n        else:\n            break\n    print(f\"check_device:{','.join(str(i) for i in device_ids)}\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/cli/compile.py",
    "content": "\"\"\"Command line entrypoint of compilation.\"\"\"\n\nimport argparse\nimport json\nimport re\nfrom functools import partial\nfrom pathlib import Path\nfrom typing import Union\n\nfrom mlc_llm.interface.compile import (  # pylint: disable=redefined-builtin\n    ModelConfigOverride,\n    OptimizationFlags,\n    compile,\n)\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.model import MODELS\nfrom mlc_llm.quantization import QUANTIZATION\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.auto_config import (\n    detect_mlc_chat_config,\n    detect_model_type,\n    detect_quantization,\n)\nfrom mlc_llm.support.auto_target import detect_system_lib_prefix, detect_target_and_host\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and call `mlc_llm.interface.compile`.\"\"\"\n\n    def _parse_output(path: Union[str, Path]) -> Path:\n        path = Path(path)\n        if path.is_dir():\n            raise argparse.ArgumentTypeError(f\"Output cannot be a directory: {path}\")\n        parent = path.parent\n        if not parent.is_dir():\n            raise argparse.ArgumentTypeError(f\"Directory does not exist: {parent}\")\n        return path\n\n    def _parse_dir(path: Union[str, Path], auto_create: bool = False) -> Path:\n        path = Path(path)\n        if not auto_create and not path.is_dir():\n            raise argparse.ArgumentTypeError(f\"Directory does not exist: {path}\")\n        if auto_create and not path.is_dir():\n            path.mkdir(parents=True)\n        return path\n\n    def _check_system_lib_prefix(prefix: str) -> str:\n        pattern = r\"^[a-zA-Z_][a-zA-Z0-9_]*$\"\n        if prefix == \"\" or re.match(pattern, prefix):\n            return prefix\n        raise argparse.ArgumentTypeError(\n            \"Invalid prefix. 
It should only consist of \"\n            \"numbers (0-9), alphabets (A-Z, a-z) and underscore (_).\"\n        )\n\n    parser = ArgumentParser(\"mlc_llm compile\")\n    parser.add_argument(\n        \"model\",\n        type=detect_mlc_chat_config,\n        help=HELP[\"model\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--quantization\",\n        type=str,\n        choices=list(QUANTIZATION.keys()),\n        help=HELP[\"quantization\"]\n        + \" (default: look up mlc-chat-config.json, choices: %(choices)s)\",\n    )\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        default=\"auto\",\n        choices=[\"auto\"] + list(MODELS.keys()),\n        help=HELP[\"model_type\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"device_compile\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--host\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"host\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--opt\",\n        type=OptimizationFlags.from_str,\n        default=\"O2\",\n        help=HELP[\"opt\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--system-lib-prefix\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"system_lib_prefix\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=_parse_output,\n        required=True,\n        help=HELP[\"output_compile\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--overrides\",\n        type=ModelConfigOverride.from_str,\n        default=\"\",\n        help=HELP[\"overrides\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--debug-dump\",\n        type=partial(_parse_dir, auto_create=True),\n        default=None,\n        
help=HELP[\"debug_dump\"] + \" (default: %(default)s)\",\n    )\n    parsed = parser.parse_args(argv)\n    target, build_func = detect_target_and_host(parsed.device, parsed.host)\n    parsed.model_type = detect_model_type(parsed.model_type, parsed.model)\n    parsed.quantization = detect_quantization(parsed.quantization, parsed.model)\n    parsed.system_lib_prefix = detect_system_lib_prefix(\n        parsed.device,\n        parsed.system_lib_prefix,\n        parsed.model_type.name,\n        parsed.quantization.name,\n    )\n    with open(parsed.model, \"r\", encoding=\"utf-8\") as config_file:\n        config = json.load(config_file)\n\n    compile(\n        config=config,\n        quantization=parsed.quantization,\n        model_type=parsed.model_type,\n        target=target,\n        opt=parsed.opt,\n        build_func=build_func,\n        system_lib_prefix=parsed.system_lib_prefix,\n        output=parsed.output,\n        overrides=parsed.overrides,\n        debug_dump=parsed.debug_dump,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/convert_weight.py",
    "content": "\"\"\"Command line entrypoint of weight conversion.\"\"\"\n\nimport argparse\nfrom pathlib import Path\nfrom typing import Union\n\nfrom mlc_llm.interface.convert_weight import convert_weight\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.model import MODELS\nfrom mlc_llm.quantization import QUANTIZATION\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.auto_config import detect_config, detect_model_type\nfrom mlc_llm.support.auto_device import detect_device\nfrom mlc_llm.support.auto_weight import detect_weight\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and apply quantization.\"\"\"\n\n    def _parse_source(path: Union[str, Path], config_path: Path) -> Path:\n        if path == \"auto\":\n            return config_path.parent\n        path = Path(path)\n        if not path.exists():\n            raise argparse.ArgumentTypeError(f\"Model source does not exist: {path}\")\n        return path\n\n    def _parse_output(path: Union[str, Path]) -> Path:\n        path = Path(path)\n        if not path.is_dir():\n            path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    def _parse_lora_adapter(path: Union[str, Path]) -> Path:\n        path = Path(path)\n        if not path.exists() or not path.is_dir():\n            raise argparse.ArgumentTypeError(f\"LoRA adapter directory does not exist: {path}\")\n        return path\n\n    parser = ArgumentParser(\"MLC AutoLLM Quantization Framework\")\n    parser.add_argument(\n        \"config\",\n        type=detect_config,\n        help=HELP[\"config\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--quantization\",\n        type=str,\n        required=True,\n        choices=list(QUANTIZATION.keys()),\n        help=HELP[\"quantization\"] + \" (required, choices: %(choices)s)\",\n    )\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        default=\"auto\",\n        choices=[\"auto\"] + list(MODELS.keys()),\n        help=HELP[\"model_type\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--device\",\n        default=\"auto\",\n        type=detect_device,\n        help=HELP[\"device_quantize\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--source\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"source\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--source-format\",\n        type=str,\n        choices=[\"auto\", \"huggingface-torch\", \"huggingface-safetensor\", \"awq\"],\n        default=\"auto\",\n        help=HELP[\"source_format\"] + ' (default: \"%(default)s\", choices: %(choices)s)',\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=_parse_output,\n        required=True,\n        help=HELP[\"output_quantize\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--lora-adapter\",\n        type=_parse_lora_adapter,\n        default=None,\n        help=(\n            \"Path to a LoRA adapter directory in PEFT format. \"\n            \"When provided, adapter weights are merged into the base model before quantization.\"\n        ),\n    )\n\n    parsed = parser.parse_args(argv)\n    parsed.source, parsed.source_format = detect_weight(\n        _parse_source(parsed.source, parsed.config),\n        parsed.config,\n        parsed.source_format,\n    )\n    model = detect_model_type(parsed.model_type, parsed.config)\n    convert_weight(\n        config=parsed.config,\n        quantization=QUANTIZATION[parsed.quantization],\n        model=model,\n        device=parsed.device,\n        source=parsed.source,\n        source_format=parsed.source_format,\n        output=parsed.output,\n        lora_adapter=parsed.lora_adapter,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/delivery.py",
    "content": "\"\"\"Continuous model delivery for MLC LLM models.\"\"\"\n\nimport argparse\nimport json\nimport os\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union\n\nfrom huggingface_hub import HfApi, snapshot_download  # pylint: disable=import-error\nfrom huggingface_hub.utils import HfHubHTTPError  # pylint: disable=import-error\nfrom pydantic import BaseModel, Field, ValidationError\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.style import bold, green, red\n\nlogger = logging.getLogger(__name__)\n\nGEN_CONFIG_OPTIONAL_ARGS = [\n    \"context_window_size\",\n    \"sliding_window_size\",\n    \"prefill_chunk_size\",\n    \"attention_sink_size\",\n    \"tensor_parallel_shards\",\n    \"pipeline_parallel_stages\",\n]\n\nT = TypeVar(\"T\", bound=\"BaseModel\")\n\n\nclass OverrideConfigs(BaseModel):\n    \"\"\"\n    The class that specifies the override configurations.\n    \"\"\"\n\n    context_window_size: Optional[int] = None\n    sliding_window_size: Optional[int] = None\n    prefill_chunk_size: Optional[int] = None\n    attention_sink_size: Optional[int] = None\n    tensor_parallel_shards: Optional[int] = None\n    pipeline_parallel_stages: Optional[int] = None\n\n\nclass ModelDeliveryTask(BaseModel):\n    \"\"\"\n    Example:\n    {\n        \"model_id\": \"Phi-3-mini-128k-instruct\",\n        \"model\": \"HF://microsoft/Phi-3-mini-128k-instruct\",\n        \"conv_template\": \"phi-3\",\n        \"quantization\": [\"q3f16_1\"],\n        \"overrides\": {\n            \"q3f16_1\": {\n                \"context_window_size\": 512\n            }\n        }\n    }\n    \"\"\"\n\n    model_id: str\n    model: str\n    conv_template: str\n    quantization: Union[List[str], str] = Field(default_factory=list)\n    overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)\n    destination: Optional[str] = 
None\n    gen_config_only: Optional[bool] = False\n\n\nclass ModelDeliveryList(BaseModel):\n    \"\"\"\n    The class that specifies the model delivery list.\n    \"\"\"\n\n    tasks: List[ModelDeliveryTask]\n    # For delivered log, the default destination and quantization fields are optional\n    default_destination: Optional[str] = None\n    default_quantization: List[str] = Field(default_factory=list)\n    default_overrides: Dict[str, OverrideConfigs] = Field(default_factory=dict)\n\n    @classmethod\n    def from_json(cls: Type[T], json_dict: Dict[str, Any]) -> T:\n        \"\"\"\n        Convert from a json dictionary.\n        \"\"\"\n        try:\n            return ModelDeliveryList.model_validate(json_dict)\n        except ValidationError as e:\n            logger.error(\"Error validating ModelDeliveryList: %s\", e)\n            raise e\n\n    def to_json(self) -> Dict[str, Any]:\n        \"\"\"\n        Convert to a json dictionary.\n        \"\"\"\n        return self.model_dump(exclude_none=True)\n\n\ndef _clone_repo(model: Union[str, Path], hf_local_dir: Optional[str]) -> str:\n    if isinstance(model, Path):\n        if not model.exists():\n            raise ValueError(f\"Invalid model source: {model}\")\n        return str(model)\n    prefixes = [\"HF://\", \"https://huggingface.co/\"]\n    # Fall back to an empty prefix so that local paths reach the existence check below\n    mlc_prefix = next((p for p in prefixes if model.startswith(p)), \"\")\n    if mlc_prefix:\n        repo_name = model[len(mlc_prefix) :]\n        model_name = repo_name.split(\"/\")[-1]\n        if hf_local_dir:\n            hf_local_dir = os.path.join(hf_local_dir, model_name)\n            logger.info(\"[HF] Downloading model to %s\", hf_local_dir)\n        return snapshot_download(repo_id=repo_name, local_dir=hf_local_dir)\n    result = Path(model)\n    if result.exists():\n        return model\n    raise ValueError(f\"Invalid model source: {model}\")\n\n\ndef _run_quantization(\n    model_info: ModelDeliveryTask,\n    repo: str,\n    api: HfApi,\n    output_dir: str,\n) -> bool:\n    logger.info(\"[HF] Creating repo https://huggingface.co/%s\", repo)\n    try:\n        api.create_repo(repo_id=repo, private=False)\n    except HfHubHTTPError as error:\n        if error.response.status_code != 409:\n            raise\n        logger.info(\"[HF] Repo already exists. Skipping creation.\")\n    succeeded = True\n    log_path = Path(output_dir) / \"logs.txt\"\n    with log_path.open(\"a\", encoding=\"utf-8\") as log_file:\n        assert isinstance(model_info.quantization, str)\n        logger.info(\"[MLC] Processing in directory: %s\", output_dir)\n        # Required arguments\n        cmd = [\n            sys.executable,\n            \"-m\",\n            \"mlc_llm\",\n            \"gen_config\",\n            model_info.model,\n            \"--quantization\",\n            model_info.quantization,\n            \"--conv-template\",\n            model_info.conv_template,\n            \"--output\",\n            output_dir,\n        ]\n        # Optional arguments, taken from the overrides registered for this quantization\n        override = model_info.overrides.get(model_info.quantization, OverrideConfigs())\n        for optional_arg in GEN_CONFIG_OPTIONAL_ARGS:\n            optional_arg_val = getattr(override, optional_arg, None)\n            if optional_arg_val is not None:\n                # e.g. 
--context-window-size 4096\n                cmd += [\"--\" + optional_arg.replace(\"_\", \"-\"), str(optional_arg_val)]\n\n        print(\" \".join(cmd), file=log_file, flush=True)\n        subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT, env=os.environ)\n        if not model_info.gen_config_only:\n            cmd = [\n                sys.executable,\n                \"-m\",\n                \"mlc_llm\",\n                \"convert_weight\",\n                str(model_info.model),\n                \"--quantization\",\n                model_info.quantization,\n                \"--output\",\n                output_dir,\n            ]\n            print(\" \".join(cmd), file=log_file, flush=True)\n            subprocess.run(\n                cmd,\n                check=False,\n                stdout=log_file,\n                stderr=subprocess.STDOUT,\n                env=os.environ,\n            )\n        logger.info(\"[MLC] Complete!\")\n    if not (Path(output_dir) / \"tensor-cache.json\").exists() and not model_info.gen_config_only:\n        logger.error(\n            \"[%s] Model %s. Quantization %s. No weights metadata found.\",\n            red(\"FAILED\"),\n            model_info.model_id,\n            model_info.quantization,\n        )\n        succeeded = False\n    logger.info(\"[HF] Uploading to: https://huggingface.co/%s\", repo)\n    for _retry in range(10):\n        try:\n            api.upload_folder(\n                folder_path=output_dir,\n                repo_id=repo,\n                ignore_patterns=[\"logs.txt\"],\n            )\n        except Exception as exc:  # pylint: disable=broad-except\n            logger.error(\"[%s] %s. 
Retrying...\", red(\"FAILED\"), exc)\n        else:\n            break\n    else:\n        raise RuntimeError(\"Failed to upload to HuggingFace Hub with 10 retries\")\n    return succeeded\n\n\ndef _get_current_log(log: str) -> ModelDeliveryList:\n    log_path = Path(log)\n    if not log_path.exists():\n        with log_path.open(\"w\", encoding=\"utf-8\") as o_f:\n            current_log = ModelDeliveryList(tasks=[])\n            json.dump(current_log.to_json(), o_f, indent=4)\n    else:\n        with log_path.open(\"r\", encoding=\"utf-8\") as i_f:\n            current_log = ModelDeliveryList.from_json(json.load(i_f))\n    return current_log\n\n\ndef _generate_model_delivery_diff(  # pylint: disable=too-many-locals\n    spec: ModelDeliveryList, log: ModelDeliveryList\n) -> ModelDeliveryList:\n    diff_tasks = []\n    default_quantization = spec.default_quantization\n    default_overrides = spec.default_overrides\n\n    for task in spec.tasks:\n        model_id = task.model_id\n        conv_template = task.conv_template\n        # A single quantization may be given as a plain string; wrap it in a list so\n        # that set(quantization) below does not split it into characters.\n        quantization = (\n            [task.quantization] if isinstance(task.quantization, str) else task.quantization\n        )\n        overrides = {**default_overrides, **task.overrides}\n\n        logger.info(\n            \"Checking task: %s %s %s %s\",\n            model_id,\n            conv_template,\n            quantization,\n            overrides,\n        )\n        log_tasks = [t for t in log.tasks if t.model_id == model_id]\n        delivered_quantizations = set()\n        gen_config_only = set()\n\n        for log_task in log_tasks:\n            log_quantization = log_task.quantization\n            assert isinstance(log_quantization, str)\n            log_override = log_task.overrides.get(log_quantization, OverrideConfigs())\n            override = overrides.get(log_quantization, OverrideConfigs())\n            if log_override == override:\n                if log_task.conv_template == conv_template:\n                    delivered_quantizations.add(log_quantization)\n                else:\n                    gen_config_only.add(log_quantization)\n\n        all_quantizations = set(default_quantization) | set(quantization)\n        quantization_diff = all_quantizations - set(delivered_quantizations)\n\n        if quantization_diff:\n            for q in quantization_diff:\n                logger.info(\"Adding task %s %s %s to the diff.\", model_id, conv_template, q)\n                task_copy = task.model_copy()\n                task_copy.quantization = [q]\n                task_copy.overrides = {q: overrides.get(q, OverrideConfigs())}\n                task_copy.gen_config_only = task_copy.gen_config_only or q in gen_config_only\n                diff_tasks.append(task_copy)\n        else:\n            logger.info(\"Task %s %s %s is up-to-date.\", model_id, conv_template, quantization)\n\n    diff_config = spec.model_copy()\n    diff_config.default_quantization = []\n    diff_config.default_overrides = {}\n    diff_config.tasks = diff_tasks\n\n    logger.info(\n        \"Model delivery diff: %s\",\n        diff_config.model_dump_json(indent=4, exclude_none=True),\n    )\n\n    return diff_config\n\n\ndef _main(  # pylint: disable=too-many-locals, too-many-arguments\n    username: str,\n    api: HfApi,\n    spec: ModelDeliveryList,\n    log: str,\n    hf_local_dir: Optional[str],\n    output: str,\n    dry_run: bool,\n):\n    delivery_diff = _generate_model_delivery_diff(spec, _get_current_log(log))\n    if dry_run:\n        logger.info(\"Dry run. 
No actual delivery.\")\n        return\n\n    failed_cases: List[Tuple[str, str]] = []\n    delivered_log = _get_current_log(log)\n    for task_index, task in enumerate(delivery_diff.tasks, 1):\n        logger.info(  # pylint: disable=logging-not-lazy\n            bold(\"[{task_index}/{total_tasks}] Processing model: \").format(\n                task_index=task_index,\n                total_tasks=len(delivery_diff.tasks),\n            )\n            + green(task.model_id)\n        )\n        model = _clone_repo(task.model, hf_local_dir)\n\n        quantizations = []\n\n        if delivery_diff.default_quantization:\n            quantizations += delivery_diff.default_quantization\n\n        if task.quantization:\n            if isinstance(task.quantization, str):\n                quantizations.append(task.quantization)\n            else:\n                quantizations += task.quantization\n\n        default_destination = (\n            delivery_diff.default_destination or \"{username}/{model_id}-{quantization}-MLC\"\n        )\n        for quantization in quantizations:\n            repo = default_destination.format(\n                username=username,\n                model_id=task.model_id,\n                quantization=quantization,\n            )\n            model_info = ModelDeliveryTask(\n                model=model,\n                quantization=quantization,\n                destination=repo,\n                **task.model_dump(exclude_none=True, exclude={\"model\", \"quantization\"}),\n            )\n            logger.info(\"Model info: %s\", model_info.model_dump_json(indent=4))\n            output_dir = os.path.join(\n                output, f\"{model_info.model_id}-{model_info.quantization}-MLC\"\n            )\n            if not os.path.exists(output_dir):\n                os.makedirs(output_dir)\n\n            result = _run_quantization(\n                model_info=model_info,\n                repo=repo,\n                api=api,\n                
output_dir=output_dir,\n            )\n            if not result:\n                failed_cases.append(\n                    (task.model_id, quantization),\n                )\n            else:\n                delivered_log.tasks = [\n                    log_task\n                    for log_task in delivered_log.tasks\n                    if log_task.model_id != model_info.model_id\n                    or log_task.quantization != model_info.quantization\n                ]\n                delivered_log.tasks.append(model_info)\n    if failed_cases:\n        logger.info(\"Total %s %s:\", len(failed_cases), red(\"failures\"))\n        for model_id, quantization in failed_cases:\n            logger.info(\"  Model %s. Quantization %s.\", model_id, quantization)\n\n    delivered_log.tasks.sort(key=lambda task: task.model_id)\n    logger.info(\"Writing log to %s\", log)\n    with open(log, \"w\", encoding=\"utf-8\") as o_f:\n        json.dump(delivered_log.to_json(), o_f, indent=4)\n\n\ndef main():\n    \"\"\"Entry point.\"\"\"\n\n    def _load_spec(path_spec: str) -> ModelDeliveryList:\n        path = Path(path_spec)\n        if not path.exists():\n            raise argparse.ArgumentTypeError(f\"Spec file does not exist: {path}\")\n        with path.open(\"r\", encoding=\"utf-8\") as i_f:\n            return ModelDeliveryList.from_json(json.load(i_f))\n\n    def _get_default_hf_token() -> str:\n        # Try to get the token from the environment variable\n        hf_token = os.getenv(\"HF_TOKEN\")\n        if hf_token:\n            logger.info(\"HF token found in environment variable HF_TOKEN\")\n            return hf_token\n\n        # If not found, look for the token in the default cache folder\n        token_file_path = os.path.expanduser(\"~/.cache/huggingface/token\")\n        if os.path.exists(token_file_path):\n            with open(token_file_path, \"r\", encoding=\"utf-8\") as token_file:\n                hf_token = token_file.read().strip()\n                if hf_token:\n                    logger.info(\"HF token found in ~/.cache/huggingface/token\")\n                    return hf_token\n\n        raise EnvironmentError(\"HF token not found\")\n\n    parser = ArgumentParser(\"MLC LLM continuous model delivery\")\n    parser.add_argument(\n        \"--username\",\n        type=str,\n        required=True,\n        help=\"HuggingFace username\",\n    )\n    parser.add_argument(\n        \"--token\",\n        type=str,\n        default=None,\n        help=\"HuggingFace access token, obtained under https://huggingface.co/settings/tokens\"\n        + \" (default: read from HF_TOKEN or ~/.cache/huggingface/token)\",\n    )\n    parser.add_argument(\n        \"--spec\",\n        type=_load_spec,\n        default=\"model-delivery-config.json\",\n        help=\"Path to the model delivery file\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--log\",\n        type=str,\n        default=\"model-delivered-log.json\",\n        help=\"Path to the output log file\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--output\",\n        type=str,\n        required=True,\n        help=\"Directory to store the output MLC models\",\n    )\n    parser.add_argument(\n        \"--hf-local-dir\",\n        type=str,\n        required=False,\n        help=\"Local directory to store the downloaded HuggingFace model\",\n    )\n    parser.add_argument(\n        \"--dry-run\",\n        action=\"store_true\",\n        help=\"Dry run without uploading to HuggingFace Hub\",\n    )\n    parsed = parser.parse_args()\n    _main(\n        parsed.username,\n        spec=parsed.spec,\n        log=parsed.log,\n        # Resolve the token lazily so that an explicit --token (or --help) works without a cached one\n        api=HfApi(token=parsed.token or _get_default_hf_token()),\n        hf_local_dir=parsed.hf_local_dir,\n        output=parsed.output,\n        dry_run=parsed.dry_run,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/cli/disco_remote_socket_session.py",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\"\"\"Internal remote disco socket session.\"\"\"\n\nimport sys\n\nfrom tvm import runtime as _  # pylint: disable=unused-import\nfrom tvm_ffi import get_global_func\n\nfrom .. import base  # pylint: disable=unused-import, no-name-in-module\n\nif __name__ == \"__main__\":\n    if len(sys.argv) != 4:\n        print(\"Usage: <server_host> <server_port> <num_workers>\")\n        sys.exit(1)\n\n    server_host = sys.argv[1]\n    server_port = int(sys.argv[2])\n    num_workers = int(sys.argv[3])\n    func = get_global_func(\"runtime.disco.RemoteSocketSession\")\n    func(server_host, server_port, num_workers)\n"
  },
  {
    "path": "python/mlc_llm/cli/gen_config.py",
    "content": "\"\"\"Command line entrypoint of configuration generation.\"\"\"\n\nfrom pathlib import Path\nfrom typing import Union\n\nfrom mlc_llm.interface.gen_config import CONV_TEMPLATES, gen_config\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.model import MODELS\nfrom mlc_llm.quantization import QUANTIZATION\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.auto_config import detect_config, detect_model_type\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and call `mlc_llm.interface.gen_config`.\"\"\"\n    parser = ArgumentParser(\"MLC LLM Configuration Generator\")\n\n    def _parse_output(path: Union[str, Path]) -> Path:\n        path = Path(path)\n        if not path.is_dir():\n            path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    parser.add_argument(\n        \"config\",\n        type=detect_config,\n        help=HELP[\"config\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--quantization\",\n        type=str,\n        required=True,\n        choices=list(QUANTIZATION.keys()),\n        help=HELP[\"quantization\"] + \" (required, choices: %(choices)s)\",\n    )\n    parser.add_argument(\n        \"--model-type\",\n        type=str,\n        default=\"auto\",\n        choices=[\"auto\"] + list(MODELS.keys()),\n        help=HELP[\"model_type\"] + ' (default: \"%(default)s\", choices: %(choices)s)',\n    )\n    parser.add_argument(\n        \"--conv-template\",\n        type=str,\n        required=True,\n        choices=list(CONV_TEMPLATES),\n        help=HELP[\"conv_template\"] + \" (required, choices: %(choices)s)\",\n    )\n    parser.add_argument(\n        \"--context-window-size\",\n        type=int,\n        default=None,\n        help=HELP[\"context_window_size\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--sliding-window-size\",\n        type=int,\n        default=None,\n        help=HELP[\"sliding_window_size\"] + ' 
(default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--prefill-chunk-size\",\n        type=int,\n        default=None,\n        help=HELP[\"prefill_chunk_size\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--attention-sink-size\",\n        type=int,\n        default=None,\n        help=HELP[\"attention_sink_size\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--tensor-parallel-shards\",\n        type=int,\n        default=None,\n        help=HELP[\"tensor_parallel_shards\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--pipeline-parallel-stages\",\n        type=int,\n        default=None,\n        help=HELP[\"pipeline_parallel_stages\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--disaggregation\",\n        type=bool,\n        default=None,\n        help=HELP[\"disaggregation\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--max-batch-size\",\n        type=int,\n        default=128,\n        help=HELP[\"max_batch_size\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=_parse_output,\n        required=True,\n        help=HELP[\"output_gen_mlc_chat_config\"] + \" (required)\",\n    )\n    parsed = parser.parse_args(argv)\n    model = detect_model_type(parsed.model_type, parsed.config)\n    gen_config(\n        config=parsed.config,\n        model=model,\n        quantization=QUANTIZATION[parsed.quantization],\n        conv_template=parsed.conv_template,\n        context_window_size=parsed.context_window_size,\n        sliding_window_size=parsed.sliding_window_size,\n        prefill_chunk_size=parsed.prefill_chunk_size,\n        attention_sink_size=parsed.attention_sink_size,\n        tensor_parallel_shards=parsed.tensor_parallel_shards,\n        pipeline_parallel_stages=parsed.pipeline_parallel_stages,\n        
disaggregation=parsed.disaggregation,\n        max_batch_size=parsed.max_batch_size,\n        output=parsed.output,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/lib_delivery.py",
    "content": "\"\"\"Continuous model delivery for MLC LLM models.\"\"\"\n\nimport argparse\nimport dataclasses\nimport json\nimport os\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, List\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.constants import MLC_TEMP_DIR\nfrom mlc_llm.support.style import bold, green, red\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass ModelInfo:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Necessary information for the model delivery\"\"\"\n\n    model_id: str\n    model: Path\n    quantization: str\n    device: str\n    # overrides the `context_window_size`, `prefill_chunk_size`,\n    # `sliding_window_size`, `attention_sink_size`, `max_batch_size`\n    # and `tensor_parallel_shards` in mlc-chat-config.json\n    overrides: Dict[str, int]\n\n\nclass DeferredScope:\n    \"\"\"A context manager that defers execution of functions until exiting the scope.\"\"\"\n\n    def __init__(self):\n        self.deferred_functions = []\n\n    def add(self, func: Callable[[], None]):\n        \"\"\"Add a function to be executed when exiting the scope.\"\"\"\n        self.deferred_functions.append(func)\n\n    def __enter__(self):\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        for func in reversed(self.deferred_functions):\n            func()\n        return False\n\n    def create_temp_dir(self) -> Path:\n        \"\"\"Create a temporary directory that will be deleted when exiting the scope.\"\"\"\n        temp_dir = tempfile.mkdtemp(dir=MLC_TEMP_DIR)\n        self.add(lambda: shutil.rmtree(temp_dir, ignore_errors=True))\n        return Path(temp_dir)\n\n\ndef _run_compilation(model_info: ModelInfo, repo_dir: Path) -> bool:\n    \"\"\"Run the compilation of the model library.\"\"\"\n\n    def get_lib_ext(device: str) -> 
str:\n        if device in [\"cuda\", \"vulkan\", \"metal\"]:\n            return \".so\"\n        if device in [\"android\", \"ios\"]:\n            return \".tar\"\n        if device in [\"webgpu\"]:\n            return \".wasm\"\n\n        return \"\"\n\n    succeeded = True\n    with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as temp_dir:\n        log_path = Path(temp_dir) / \"logs.txt\"\n        model_lib_name = f\"{model_info.model_id}-{model_info.quantization}-{model_info.device}\"\n        lib_ext = get_lib_ext(model_info.device)\n        if lib_ext == \"\":\n            raise ValueError(f\"Unsupported device: {model_info.device}\")\n        model_lib_name += lib_ext\n        with log_path.open(\"a\", encoding=\"utf-8\") as log_file:\n            overrides = \";\".join(f\"{key}={value}\" for key, value in model_info.overrides.items())\n            cmd = [\n                sys.executable,\n                \"-m\",\n                \"mlc_llm\",\n                \"compile\",\n                str(model_info.model),\n                \"--device\",\n                model_info.device,\n                \"--quantization\",\n                model_info.quantization,\n                \"--overrides\",\n                overrides,\n                \"--output\",\n                os.path.join(temp_dir, model_lib_name),\n            ]\n            print(\" \".join(cmd), file=log_file, flush=True)\n            subprocess.run(cmd, check=True, stdout=log_file, stderr=subprocess.STDOUT)\n            logger.info(\"[MLC] Compilation Complete!\")\n        if not (Path(temp_dir) / model_lib_name).exists():\n            logger.error(\n                \"[%s] Model %s. Device %s. 
No compiled library found.\",\n                red(\"FAILED\"),\n                model_info.model_id,\n                model_info.device,\n            )\n            succeeded = False\n            return succeeded\n\n        # overwrite git repo file with the compiled library\n        repo_filepath = repo_dir / model_info.model_id / model_lib_name\n        if not repo_filepath.parent.exists():\n            repo_filepath.parent.mkdir(parents=True, exist_ok=True)\n        # copy lib from Path(temp_dir) / model_lib_name to repo_filepath\n        shutil.copy(Path(temp_dir) / model_lib_name, repo_filepath)\n        logger.info(\"Saved library %s at %s\", model_lib_name, repo_filepath)\n    return succeeded\n\n\ndef _main(  # pylint: disable=too-many-locals\n    spec: Dict[str, Any],\n):\n    \"\"\"Compile the model libs in the spec and save them to the binary_libs_dir.\"\"\"\n    failed_cases: List[Any] = []\n    for task_index, task in enumerate(spec[\"tasks\"], 1):\n        logger.info(  # pylint: disable=logging-not-lazy\n            bold(\"[{task_index}/{total_tasks}] Processing model: \").format(\n                task_index=task_index,\n                total_tasks=len(spec[\"tasks\"]),\n            )\n            + green(task[\"model_id\"])\n        )\n        model_info = {\n            \"model_id\": task[\"model_id\"],\n            \"model\": task[\"model\"],\n        }\n        for compile_opt in spec[\"default_compile_options\"] + task.get(\"compile_options\", []):\n            for quantization in spec[\"default_quantization\"] + task.get(\"quantization\", []):\n                model_info[\"quantization\"] = quantization\n                model_info[\"device\"] = compile_opt[\"device\"]\n                model_info[\"overrides\"] = compile_opt.get(\"overrides\", {})\n                logger.info(\n                    \"[Config] \"\n                    + bold(\"model_id: \")\n                    + model_info[\"model_id\"]\n                    + bold(\", 
quantization: \")\n                    + model_info[\"quantization\"]\n                    + bold(\", device: \")\n                    + model_info[\"device\"]\n                    + bold(\", overrides: \")\n                    + json.dumps(model_info[\"overrides\"])\n                )\n\n                result = _run_compilation(\n                    ModelInfo(**model_info),\n                    repo_dir=Path(spec[\"binary_libs_dir\"]),\n                )\n                if not result:\n                    failed_cases.append(model_info)\n\n    if failed_cases:\n        logger.info(\"Total %s %s:\", len(failed_cases), red(\"failures\"))\n        for case in failed_cases:\n            logger.info(\n                \"model_id %s, quantization %s, device %s, overrides %s\",\n                case[\"model_id\"],\n                case[\"quantization\"],\n                case[\"device\"],\n                json.dumps(case[\"overrides\"]),\n            )\n\n\ndef main():\n    \"\"\"Entry point.\"\"\"\n\n    def _load_spec(path_spec: str) -> Dict[str, Any]:\n        path = Path(path_spec)\n        if not path.exists():\n            raise argparse.ArgumentTypeError(f\"Spec file does not exist: {path}\")\n        with path.open(\"r\", encoding=\"utf-8\") as i_f:\n            return json.load(i_f)\n\n    parser = ArgumentParser(\"MLC LLM continuous library delivery\")\n    parser.add_argument(\n        \"--spec\",\n        type=_load_spec,\n        required=True,\n        help=\"Path to the spec file\",\n    )\n    parsed = parser.parse_args()\n    _main(\n        spec=parsed.spec,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/cli/model_metadata.py",
    "content": "\"\"\"A tool that inspects the metadata of a model lib.\"\"\"\n\nimport json\nimport math\nfrom dataclasses import asdict\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Union\n\nfrom tvm.runtime import DataType\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import green, red\n\nlogger = logging.getLogger(__name__)\n\n\ndef _extract_metadata(model_lib: Path) -> Dict[str, Any]:\n    # pylint: disable=import-outside-toplevel\n    from tvm.runtime import device, load_module\n    from tvm.runtime.vm import VirtualMachine\n\n    # pylint: enable=import-outside-toplevel\n\n    return json.loads(VirtualMachine(load_module(model_lib), device(\"cpu\"))[\"_metadata\"]())\n\n\ndef _report_all(metadata: Dict[str, Any]) -> None:\n    # Print JSON with aesthetic values that packs each parameter into one line,\n    # while keeping the rest indented.\n    indent = 2\n    indents = \" \" * indent\n    params = metadata.pop(\"params\")\n    params = indents * 2 + (\",\\n\" + indents * 2).join(json.dumps(p) for p in params)\n    lines = json.dumps(\n        metadata,\n        sort_keys=True,\n        indent=indent,\n    ).splitlines()\n    lines.insert(1, indents + '\"params\": [\\n' + params + \"\\n\" + indents + \"],\")\n    beautified_json = \"\\n\".join(lines)\n    print(beautified_json)\n\n\ndef _read_dynamic_shape(shape: List[Union[int, str]], config: Union[Dict, ConfigBase]) -> List[int]:\n    if isinstance(config, ConfigBase):\n        config = asdict(config)\n    param_shape = []\n    for s in shape:\n        if isinstance(s, int):\n            param_shape.append(s)\n        else:\n            if config is None:\n                logger.error(\n                    \"%s: Encountered dynamic shape %s, need to specify `--mlc-chat-config` for \"\n                    + \"memory usage calculation.\",\n                    
red(\"FAILED\"),\n                    red(s),\n                )\n                raise AttributeError\n            if s not in config:\n                logger.error(\n                    \"%s to retrieve concrete %s for dynamic shape from %s.\",\n                    red(\"FAILED\"),\n                    red(s),\n                    config,\n                )\n                raise KeyError\n            param_shape.append(config[s])\n    return param_shape\n\n\ndef _compute_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBase]):\n    params_bytes = 0.0\n    for param in metadata[\"params\"]:\n        if all(isinstance(v, int) for v in param[\"shape\"]):\n            assert all(v > 0 for v in param[\"shape\"]), \"All shapes should be strictly positive.\"\n            param_shape = param[\"shape\"]\n        else:\n            # Contains dynamic shape; use config to look up concrete values\n            param_shape = _read_dynamic_shape(param[\"shape\"], config)\n        params_bytes += math.prod(param_shape) * DataType(param[\"dtype\"]).itemsize\n    temp_func_bytes = 0.0\n    for _func_name, func_bytes in metadata[\"memory_usage\"].items():\n        temp_func_bytes = max(temp_func_bytes, func_bytes)\n\n    return params_bytes, temp_func_bytes\n\n\ndef _report_memory_usage(metadata: Dict[str, Any], config: Union[Dict, ConfigBase]) -> None:\n    params_bytes, temp_func_bytes = _compute_memory_usage(metadata, config)\n    total_size = params_bytes + temp_func_bytes\n    logger.info(\n        \"%s: %.2f MB (Parameters: %.2f MB. 
Temporary buffer: %.2f MB)\",\n        green(\"Total memory usage without KV cache\"),\n        total_size / 1024 / 1024,\n        params_bytes / 1024 / 1024,\n        temp_func_bytes / 1024 / 1024,\n    )\n\n    # Compute KV cache size per token of context window.\n    if isinstance(config, ConfigBase):\n        config = asdict(config)\n    if (\n        \"head_dim\" in config\n        and \"num_hidden_layers\" in config\n        and \"num_key_value_heads\" in config\n        and \"quantization\" in metadata\n    ):\n        quantization_type = metadata[\"quantization\"]\n        dtype_bytes = None\n        if \"f32\" in quantization_type:\n            dtype_bytes = 4\n        elif \"bf16\" in quantization_type:\n            dtype_bytes = 2\n        elif \"f16\" in quantization_type:\n            dtype_bytes = 2\n        # TODO: If support quantized KV in future, need to change this  # pylint: disable=fixme\n        if dtype_bytes is not None:\n            bytes_per_token = (\n                config[\"head_dim\"]\n                * config[\"num_hidden_layers\"]\n                * config[\"num_key_value_heads\"]\n                * dtype_bytes\n                * 2  # 2 for key and value\n            )\n            logger.info(\n                \"%s: %.2f MB per token in the context window\",\n                green(\"KV cache size\"),\n                bytes_per_token / 1024 / 1024,\n            )\n            logger.info(\n                \"%s: %.2f MB\",\n                green(\"Total memory usage with a 4K KV cache\"),\n                (total_size + bytes_per_token * 4096) / 1024 / 1024,\n            )\n\n    logger.info(\n        \"To reduce memory usage, \"\n        \"tweak `prefill_chunk_size`, `context_window_size` and `sliding_window_size`\"\n    )\n\n\ndef main():\n    \"\"\"Entry point for the model metadata tool.\"\"\"\n    parser = ArgumentParser(description=\"A tool that inspects the metadata of a model lib.\")\n    parser.add_argument(\n        
\"model_lib\",\n        type=Path,\n        help=\"\"\"The compiled model library. In MLC LLM, an LLM is compiled to a shared or static\n        library (.so or .a), which contains GPU computation to efficiently run the LLM. MLC Chat,\n        as the runtime of MLC LLM, depends on the compiled model library to generate tokens.\n        \"\"\",\n    )\n    parser.add_argument(\n        \"--mlc-chat-config\",\n        type=Path,\n        help=\"\"\"The `mlc-chat-config.json` file specific to a model variant. This is only required\n        when `memory-only` is true and `model_lib` contains a dynamic parameter shape (i.e. using\n        a variable to represent the shape). For instance, `model.embed_tokens.q_weight` can have\n        shape `[\"vocab_size\", 512]`. In these cases, we look up the concrete value in\n        `mlc-chat-config.json`.\n        \"\"\",\n    )\n    parser.add_argument(\n        \"--memory-only\",\n        action=\"store_true\",\n        help=\"\"\"If set, only inspect the metadata in memory usage and print richer analysis.\n        Otherwise, the tool will load all the metadata from the model library file but only print\n        the basic information in JSON.\n        \"\"\",\n    )\n    parsed = parser.parse_args()\n    # Load metadata from model lib\n    try:\n        metadata = _extract_metadata(parsed.model_lib)\n    except:  # pylint: disable=bare-except\n        logger.exception(\"%s to read metadata section in legacy model lib.\", red(\"FAILED\"))\n        return\n    # Load mlc_chat_config if provided\n    cfg = None\n    if parsed.mlc_chat_config:\n        mlc_chat_config_path = Path(parsed.mlc_chat_config)\n        if not mlc_chat_config_path.exists():\n            raise ValueError(f\"{mlc_chat_config_path} does not exist.\")\n        with open(mlc_chat_config_path, \"r\", encoding=\"utf-8\") as config_file:\n            cfg = json.load(config_file)\n    # Main body\n    if parsed.memory_only:\n        _report_memory_usage(metadata, 
cfg)\n    else:\n        _report_all(metadata)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/cli/package.py",
    "content": "\"\"\"Command line entrypoint of package.\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Union\n\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.interface.package import package\nfrom mlc_llm.support.argparse import ArgumentParser\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and call `mlc_llm.interface.package`.\"\"\"\n    parser = ArgumentParser(\"MLC LLM Package CLI\")\n\n    def _parse_package_config(path: Union[str, Path]) -> Path:\n        path = Path(path)\n        if not path.exists():\n            raise ValueError(\n                f\"Path {str(path)} is expected to be a JSON file, but the file does not exist.\"\n            )\n        if not path.is_file():\n            raise ValueError(f\"Path {str(path)} is expected to be a JSON file.\")\n        return path\n\n    def _parse_mlc_llm_source_dir(path: str) -> Path:\n        os.environ[\"MLC_LLM_SOURCE_DIR\"] = path\n        return Path(path)\n\n    def _parse_output(path: Union[str, Path]) -> Path:\n        path = Path(path)\n        if not path.is_dir():\n            path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    parser.add_argument(\n        \"--package-config\",\n        type=_parse_package_config,\n        default=\"mlc-package-config.json\",\n        help=HELP[\"config_package\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--mlc-llm-source-dir\",\n        type=_parse_mlc_llm_source_dir,\n        default=os.environ.get(\"MLC_LLM_SOURCE_DIR\", None),\n        help=HELP[\"mlc_llm_source_dir\"]\n        + \" (default: the $MLC_LLM_SOURCE_DIR environment variable)\",\n    )\n    parser.add_argument(\n        \"--output\",\n        \"-o\",\n        type=_parse_output,\n        default=\"dist\",\n        help=HELP[\"output_package\"] + ' (default: \"%(default)s\")',\n    )\n    parsed = parser.parse_args(argv)\n    if parsed.mlc_llm_source_dir is None:\n        raise ValueError(\n            \"MLC 
LLM home is not specified. \"\n            \"Please obtain a copy of MLC LLM source code by \"\n            \"cloning https://github.com/mlc-ai/mlc-llm, and set environment variable \"\n            '\"MLC_LLM_SOURCE_DIR=path/to/mlc-llm\"'\n        )\n    package(\n        package_config_path=parsed.package_config,\n        mlc_llm_source_dir=parsed.mlc_llm_source_dir,\n        output=parsed.output,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/router.py",
    "content": "\"\"\"Command line entrypoint of router.\"\"\"\n\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.interface.router import serve\nfrom mlc_llm.support.argparse import ArgumentParser\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and call `mlc_llm.interface.router`.\"\"\"\n\n    # Define a custom argument type for a list of strings\n    def list_of_strings(arg):\n        return arg.split(\",\")\n\n    parser = ArgumentParser(\"MLC LLM Router Serve CLI\")\n    parser.add_argument(\n        \"model\",\n        type=str,\n        help=HELP[\"model\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--model-lib\",\n        type=str,\n        default=None,\n        help=HELP[\"model_lib\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--router-mode\",\n        type=str,\n        choices=[\"disagg\", \"round-robin\"],\n        default=\"disagg\",\n        help=\"router mode\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--router-host\",\n        type=str,\n        default=\"127.0.0.1\",\n        help=\"router host\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--router-port\",\n        type=int,\n        default=8000,\n        help=\"router port\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--endpoint-hosts\",\n        type=list_of_strings,\n        default=\"127.0.0.1\",\n        help=\"Host of each endpoint, separated by comma.\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--endpoint-ports\",\n        nargs=\"*\",\n        type=int,\n        default=[8080],\n        help=\"Port of each endpoint, separated by space.\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--endpoint-num-gpus\",\n        nargs=\"*\",\n        type=int,\n        default=[1],\n        help=\"Number of GPUs of each endpoint, separated by space.\" + ' (default: 
\"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--enable-prefix-cache\",\n        default=False,\n        action=\"store_true\",\n        help=\"whether to enable prefix cache\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--pd-balance-factor\",\n        type=float,\n        default=0.0,\n        help=HELP[\"pd_balance_factor\"] + ' (default: \"%(default)s\")',\n    )\n    parsed = parser.parse_args(argv)\n    serve(\n        model=parsed.model,\n        model_lib=parsed.model_lib,\n        router_host=parsed.router_host,\n        router_port=parsed.router_port,\n        endpoint_hosts=parsed.endpoint_hosts,\n        endpoint_ports=parsed.endpoint_ports,\n        endpoint_num_gpus=parsed.endpoint_num_gpus,\n        enable_prefix_cache=parsed.enable_prefix_cache,\n        router_mode=parsed.router_mode,\n        pd_balance_factor=parsed.pd_balance_factor,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/serve.py",
    "content": "\"\"\"Command line entrypoint of serve.\"\"\"\n\nimport dataclasses\nimport json\nfrom io import StringIO\nfrom typing import Literal, Optional\n\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.interface.serve import serve\nfrom mlc_llm.support import argparse\nfrom mlc_llm.support.argparse import ArgumentParser\n\n\n@dataclasses.dataclass\nclass EngineConfigOverride:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Arguments for overriding engine config.\"\"\"\n\n    # Overrides for EngineConfig (runtime)\n    max_num_sequence: Optional[int] = None\n    max_total_seq_length: Optional[int] = None\n    prefill_chunk_size: Optional[int] = None\n    max_history_size: Optional[int] = None\n    gpu_memory_utilization: Optional[float] = None\n    spec_draft_length: Optional[int] = None\n    spec_tree_width: Optional[int] = None\n    prefix_cache_mode: Optional[Literal[\"disable\", \"radix\"]] = None\n    prefix_cache_max_num_recycling_seqs: Optional[int] = None\n    prefill_mode: Optional[Literal[\"chunked\", \"hybrid\"]] = None\n    context_window_size: Optional[int] = None\n    sliding_window_size: Optional[int] = None\n    attention_sink_size: Optional[int] = None\n    tensor_parallel_shards: Optional[int] = None\n    pipeline_parallel_stages: Optional[int] = None\n    opt: Optional[str] = None\n\n    def __repr__(self) -> str:\n        out = StringIO()\n        print(f\"max_num_sequence={self.max_num_sequence}\", file=out, end=\"\")\n        print(f\";max_total_seq_length={self.max_total_seq_length}\", file=out, end=\"\")\n        print(f\";prefill_chunk_size={self.prefill_chunk_size}\", file=out, end=\"\")\n        print(f\";max_history_size={self.max_history_size}\", file=out, end=\"\")\n        print(f\";gpu_memory_utilization={self.gpu_memory_utilization}\", file=out, end=\"\")\n        print(f\";spec_draft_length={self.spec_draft_length}\", file=out, end=\"\")\n        print(f\";spec_tree_width={self.spec_tree_width}\", file=out, 
end=\"\")\n        print(f\";prefix_cache_mode={self.prefix_cache_mode}\", file=out, end=\"\")\n        print(\n            f\";prefix_cache_max_num_recycling_seqs={self.prefix_cache_max_num_recycling_seqs}\",\n            file=out,\n            end=\"\",\n        )\n        print(f\";prefill_mode={self.prefill_mode}\", file=out, end=\"\")\n        print(f\";context_window_size={self.context_window_size}\", file=out, end=\"\")\n        print(f\";sliding_window_size={self.sliding_window_size}\", file=out, end=\"\")\n        print(f\";attention_sink_size={self.attention_sink_size}\", file=out, end=\"\")\n        print(f\";tensor_parallel_shards={self.tensor_parallel_shards}\", file=out, end=\"\")\n        print(\n            f\";pipeline_parallel_stages={self.pipeline_parallel_stages}\",\n            file=out,\n            end=\"\",\n        )\n        print(f\";opt={self.opt}\", file=out, end=\"\")\n        return out.getvalue().rstrip()\n\n    @staticmethod\n    def from_str(source: str) -> \"EngineConfigOverride\":\n        \"\"\"Parse engine config override values from a string.\"\"\"\n        parser = argparse.ArgumentParser(description=\"Engine config override values\")\n\n        parser.add_argument(\"--max_num_sequence\", type=int, default=None)\n        parser.add_argument(\"--max_total_seq_length\", type=int, default=None)\n        parser.add_argument(\"--prefill_chunk_size\", type=int, default=None)\n        parser.add_argument(\"--max_history_size\", type=int, default=None)\n        parser.add_argument(\"--gpu_memory_utilization\", type=float, default=None)\n        parser.add_argument(\"--spec_draft_length\", type=int, default=None)\n        parser.add_argument(\"--spec_tree_width\", type=int, default=None)\n        parser.add_argument(\"--prefix_cache_mode\", type=str, default=\"radix\")\n        parser.add_argument(\"--prefix_cache_max_num_recycling_seqs\", type=int, default=None)\n        parser.add_argument(\"--prefill_mode\", type=str, 
default=\"hybrid\")\n        parser.add_argument(\"--context_window_size\", type=int, default=None)\n        parser.add_argument(\"--sliding_window_size\", type=int, default=None)\n        parser.add_argument(\"--attention_sink_size\", type=int, default=None)\n        parser.add_argument(\"--tensor_parallel_shards\", type=int, default=None)\n        parser.add_argument(\"--pipeline_parallel_stages\", type=int, default=None)\n        parser.add_argument(\"--opt\", type=str, default=None)\n        results = parser.parse_args([f\"--{i}\" for i in source.split(\";\") if i])\n        return EngineConfigOverride(\n            max_num_sequence=results.max_num_sequence,\n            max_total_seq_length=results.max_total_seq_length,\n            prefill_chunk_size=results.prefill_chunk_size,\n            max_history_size=results.max_history_size,\n            gpu_memory_utilization=results.gpu_memory_utilization,\n            spec_draft_length=results.spec_draft_length,\n            spec_tree_width=results.spec_tree_width,\n            prefix_cache_mode=results.prefix_cache_mode,\n            prefix_cache_max_num_recycling_seqs=results.prefix_cache_max_num_recycling_seqs,\n            prefill_mode=results.prefill_mode,\n            context_window_size=results.context_window_size,\n            sliding_window_size=results.sliding_window_size,\n            attention_sink_size=results.attention_sink_size,\n            tensor_parallel_shards=results.tensor_parallel_shards,\n            pipeline_parallel_stages=results.pipeline_parallel_stages,\n            opt=results.opt,\n        )\n\n\ndef main(argv):\n    \"\"\"Parse command line arguments and call `mlc_llm.interface.serve`.\"\"\"\n    parser = ArgumentParser(\"MLC LLM Serve CLI\")\n\n    parser.add_argument(\n        \"model\",\n        type=str,\n        help=HELP[\"model\"] + \" (required)\",\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        
help=HELP[\"device_deploy\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--model-lib\",\n        type=str,\n        default=None,\n        help=HELP[\"model_lib\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--mode\",\n        type=str,\n        choices=[\"local\", \"interactive\", \"server\"],\n        default=\"local\",\n        help=HELP[\"mode_serve\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--enable-debug\",\n        action=\"store_true\",\n        help=\"whether to enable debug endpoints and debug config when accepting requests\",\n    )\n    parser.add_argument(\n        \"--additional-models\", type=str, nargs=\"*\", help=HELP[\"additional_models_serve\"]\n    )\n    parser.add_argument(\n        \"--embedding-model\",\n        type=str,\n        default=None,\n        help=\"Path to the embedding model weight directory (enables /v1/embeddings endpoint)\",\n    )\n    parser.add_argument(\n        \"--embedding-model-lib\",\n        type=str,\n        default=None,\n        help=\"Path to the compiled embedding model library (.so/.dylib file)\",\n    )\n    parser.add_argument(\n        \"--speculative-mode\",\n        type=str,\n        choices=[\"disable\", \"small_draft\", \"eagle\", \"medusa\"],\n        default=\"disable\",\n        help=HELP[\"speculative_mode_serve\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--prefix-cache-mode\",\n        type=str,\n        choices=[\"disable\", \"radix\"],\n        default=\"radix\",\n        help=HELP[\"prefix_cache_mode_serve\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--prefill-mode\",\n        type=str,\n        choices=[\"hybrid\", \"chunked\"],\n        default=\"hybrid\",\n        help=HELP[\"prefill_mode\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--overrides\",\n        
type=EngineConfigOverride.from_str,\n        default=\"\",\n        help=HELP[\"overrides_serve\"],\n    )\n    parser.add_argument(\"--enable-tracing\", action=\"store_true\", help=HELP[\"enable_tracing_serve\"])\n    parser.add_argument(\n        \"--host\",\n        type=str,\n        default=\"127.0.0.1\",\n        help=\"host name\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--port\",\n        type=int,\n        default=8000,\n        help=\"port\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\"--allow-credentials\", action=\"store_true\", help=\"allow credentials\")\n    parser.add_argument(\n        \"--allow-origins\",\n        type=json.loads,\n        default=[\"*\"],\n        help=\"allowed origins\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--allow-methods\",\n        type=json.loads,\n        default=[\"*\"],\n        help=\"allowed methods\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--allow-headers\",\n        type=json.loads,\n        default=[\"*\"],\n        help=\"allowed headers\" + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--api-key\",\n        type=str,\n        default=None,\n        help=\"API key for authentication. 
If not provided, authentication is disabled.\",\n    )\n    parsed = parser.parse_args(argv)\n\n    additional_models = []\n    if parsed.additional_models is not None:\n        for additional_model in parsed.additional_models:\n            splits = additional_model.split(\",\", maxsplit=1)\n            if len(splits) == 2:\n                additional_models.append((splits[0], splits[1]))\n            else:\n                additional_models.append(splits[0])\n\n    serve(\n        model=parsed.model,\n        device=parsed.device,\n        model_lib=parsed.model_lib,\n        mode=parsed.mode,\n        enable_debug=parsed.enable_debug,\n        additional_models=additional_models,\n        embedding_model=parsed.embedding_model,\n        embedding_model_lib=parsed.embedding_model_lib,\n        tensor_parallel_shards=parsed.overrides.tensor_parallel_shards,\n        pipeline_parallel_stages=parsed.overrides.pipeline_parallel_stages,\n        opt=parsed.overrides.opt,\n        speculative_mode=parsed.speculative_mode,\n        prefix_cache_mode=parsed.prefix_cache_mode,\n        max_num_sequence=parsed.overrides.max_num_sequence,\n        max_total_sequence_length=parsed.overrides.max_total_seq_length,\n        max_single_sequence_length=parsed.overrides.context_window_size,\n        prefill_chunk_size=parsed.overrides.prefill_chunk_size,\n        sliding_window_size=parsed.overrides.sliding_window_size,\n        attention_sink_size=parsed.overrides.attention_sink_size,\n        max_history_size=parsed.overrides.max_history_size,\n        gpu_memory_utilization=parsed.overrides.gpu_memory_utilization,\n        spec_draft_length=parsed.overrides.spec_draft_length,\n        spec_tree_width=parsed.overrides.spec_tree_width,\n        prefix_cache_max_num_recycling_seqs=parsed.overrides.prefix_cache_max_num_recycling_seqs,\n        prefill_mode=parsed.prefill_mode,\n        enable_tracing=parsed.enable_tracing,\n        host=parsed.host,\n        port=parsed.port,\n      
  allow_credentials=parsed.allow_credentials,\n        allow_origins=parsed.allow_origins,\n        allow_methods=parsed.allow_methods,\n        allow_headers=parsed.allow_headers,\n        api_key=parsed.api_key,\n    )\n"
  },
  {
    "path": "python/mlc_llm/cli/worker.py",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n# pylint: disable=invalid-name\n\"\"\"Internal DiscoWorker for Disco ProcessSession.\"\"\"\n\nimport os\nimport sys\n\nfrom tvm import runtime as _  # pylint: disable=unused-import\nfrom tvm_ffi import get_global_func\n\nfrom .. 
import base  # pylint: disable=unused-import, no-name-in-module\n\n# register the calibration functions\nfrom ..interface import calibrate  # pylint: disable=unused-import\n\n\ndef main():\n    \"\"\"Main worker function\"\"\"\n    if len(sys.argv) != 6:\n        print(\"Usage: <worker_id> <num_workers> <num_groups> <read_fd> <write_fd>\")\n        return\n\n    worker_id = int(sys.argv[1])\n    num_workers = int(sys.argv[2])\n    num_groups = int(sys.argv[3])\n    read_fd = int(sys.argv[4])\n    write_fd = int(sys.argv[5])\n    if sys.platform == \"win32\":\n        import msvcrt  # pylint: disable=import-outside-toplevel,import-error\n\n        reader = msvcrt.open_osfhandle(read_fd, os.O_BINARY)\n        writer = msvcrt.open_osfhandle(write_fd, os.O_BINARY)\n    else:\n        reader = read_fd\n        writer = write_fd\n\n    worker_func = get_global_func(\"runtime.disco.WorkerProcess\")\n    worker_func(worker_id, num_workers, num_groups, reader, writer)\n\n\nif __name__ == \"__main__\":\n    try:\n        main()\n    except (KeyboardInterrupt, IOError):\n        pass\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/__init__.py",
    "content": "\"\"\"Compiler passes used in MLC LLM.\"\"\"\n\nfrom . import pipeline as _pipeline\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py",
    "content": "\"\"\"The pass that attaches an empty function for initialization.\"\"\"\n\nimport tvm\nfrom tvm import IRModule, relax\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachCUDAGraphAllocInitFunc\")\nclass AttachCUDAGraphAllocInitFunc:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach an empty function for initialization.\"\"\"\n\n    def __init__(self):\n        pass\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        bb = relax.BlockBuilder(mod)\n        alloc_func_gv = None\n        for gv, _ in mod.functions_items():\n            if gv.name_hint.startswith(\"cuda_graph_alloc\"):\n                assert alloc_func_gv is None\n                alloc_func_gv = gv\n        if alloc_func_gv is None:\n            return mod\n\n        with bb.function(\"cuda_graph_alloc_init\", []):\n            bb.emit_func_output(\n                relax.op.call_builtin_with_ctx(\n                    \"vm.builtin.cuda_graph.get_cached_alloc\",\n                    args=[alloc_func_gv, relax.PrimValue(0)],\n                    sinfo_args=relax.ObjectStructInfo(),\n                )\n            )\n        return bb.finalize()\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_embedding_allocator.py",
    "content": "\"\"\"The pass that attaches embedding allocation function to the IRModule.\"\"\"\n\nfrom typing import Any, Dict\n\nimport tvm\nfrom tvm import IRModule, relax\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachAllocEmbeddingTensorFunc\")\nclass AttachAllocEmbeddingTensorFunc:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach embedding tensor allocation Relax function to IRModule.\"\"\"\n\n    def __init__(self, metadata: Dict[str, Any]):\n        self.metadata = metadata\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        embed_func = None\n        for gv, func in mod.functions_items():\n            if gv.name_hint == \"embed\":\n                embed_func = func\n\n        if embed_func is None:\n            return mod\n\n        hidden_size = embed_func.ret_struct_info.shape[-1]\n        dtype = embed_func.ret_struct_info.dtype\n        bb = relax.BlockBuilder(mod)\n        with bb.function(\"alloc_embedding_tensor\", []):\n            bb.emit_func_output(\n                bb.emit(\n                    relax.op.builtin.alloc_tensor(\n                        relax.ShapeExpr([self.metadata[\"prefill_chunk_size\"], hidden_size]),\n                        dtype,\n                        runtime_device_index=0,\n                    )\n                )\n            )\n        return bb.finalize()\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_logit_processor.py",
    "content": "\"\"\"The pass that attaches logit processor functions to the IRModule.\"\"\"\n\nimport tvm\nfrom tvm import IRModule\nfrom tvm.script import tir as T\n\nfrom ..support.max_thread_check import (\n    check_thread_limits,\n    get_max_num_threads_per_block,\n)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachLogitProcessFunc\")\nclass AttachLogitProcessFunc:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach logit processing TIR functions to IRModule.\"\"\"\n\n    def __init__(self, target: tvm.target.Target):\n        \"\"\"Initializer.\n\n        Parameters\n        ----------\n        target : tvm.target.Target\n            The target of the model compilation.\n        \"\"\"\n        self.target = target\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        mod = mod.clone()\n        if str(self.target.kind) == \"llvm\":\n            mod[\"apply_logit_bias_inplace\"] = _get_apply_logit_bias_inplace_cpu()\n            mod[\"apply_penalty_inplace\"] = _get_apply_penalty_inplace_cpu()\n            mod[\"apply_bitmask_inplace\"] = _get_apply_bitmask_inplace_cpu()\n        else:\n            mod[\"apply_logit_bias_inplace\"] = _get_apply_logit_bias_inplace(self.target)\n            mod[\"apply_penalty_inplace\"] = _get_apply_penalty_inplace(self.target)\n            mod[\"apply_bitmask_inplace\"] = _get_apply_bitmask_inplace(self.target)\n        return mod\n\n\ndef _get_apply_logit_bias_inplace_cpu():\n    @T.prim_func\n    def _apply_logit_bias_inplace(\n        var_logits: T.handle,\n        var_pos2seq_id: T.handle,\n        var_token_ids: T.handle,\n        var_logit_bias: T.handle,\n    ) -> None:\n        \"\"\"Function that applies logit bias in place.\"\"\"\n        T.func_attr(\n            {\n                \"global_symbol\": \"apply_logit_bias_inplace\",\n                \"tir.noalias\": True,\n                \"tir.is_scheduled\": True,\n  
          }\n        )\n        batch_size = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        num_token = T.int32(is_size_var=True)\n        logits = T.match_buffer(var_logits, (batch_size, vocab_size), \"float32\")\n        # seq_ids\n        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), \"int32\")\n        token_ids = T.match_buffer(var_token_ids, (num_token,), \"int32\")\n        logit_bias = T.match_buffer(var_logit_bias, (num_token,), \"float32\")\n\n        for i in range(num_token):\n            logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]\n\n    return _apply_logit_bias_inplace\n\n\ndef _get_apply_logit_bias_inplace(target: tvm.target.Target):\n    tx = 1024  # default\n    max_num_threads_per_block = get_max_num_threads_per_block(target)\n    tx = min(tx, max_num_threads_per_block)\n    check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)\n\n    @T.prim_func\n    def _apply_logit_bias_inplace(\n        var_logits: T.handle,\n        var_pos2seq_id: T.handle,\n        var_token_ids: T.handle,\n        var_logit_bias: T.handle,\n    ) -> None:\n        \"\"\"Function that applies logit bias in place.\"\"\"\n        T.func_attr(\n            {\n                \"global_symbol\": \"apply_logit_bias_inplace\",\n                \"tir.noalias\": True,\n                \"tir.is_scheduled\": True,\n            }\n        )\n        batch_size = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        num_token = T.int32(is_size_var=True)\n        logits = T.match_buffer(var_logits, (batch_size, vocab_size), \"float32\")\n        # seq_ids\n        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), \"int32\")\n        token_ids = T.match_buffer(var_token_ids, (num_token,), \"int32\")\n        logit_bias = T.match_buffer(var_logit_bias, (num_token,), \"float32\")\n\n        for p0 in T.thread_binding(0, (num_token + tx - 1) // tx, \"blockIdx.x\"):\n            for p1 in 
T.thread_binding(0, tx, \"threadIdx.x\"):\n                with T.sblock(\"block\"):\n                    vp = T.axis.spatial(num_token, p0 * tx + p1)\n                    T.where(p0 * tx + p1 < num_token)\n                    logits[pos2seq_id[vp], token_ids[vp]] += logit_bias[vp]\n\n    return _apply_logit_bias_inplace\n\n\ndef _get_apply_penalty_inplace_cpu():\n    @T.prim_func\n    def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals\n        var_logits: T.handle,\n        var_seq_ids: T.handle,\n        var_pos2seq_id: T.handle,\n        var_token_ids: T.handle,\n        var_token_cnt: T.handle,\n        var_penalties: T.handle,\n    ) -> None:\n        \"\"\"Function that applies penalties in place.\"\"\"\n        T.func_attr(\n            {\n                \"global_symbol\": \"apply_penalty_inplace\",\n                \"tir.noalias\": True,\n                \"tir.is_scheduled\": True,\n            }\n        )\n        batch_size = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        num_token = T.int32(is_size_var=True)\n        num_seq = T.int32(is_size_var=True)\n        logits = T.match_buffer(var_logits, (batch_size, vocab_size), \"float32\")\n        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), \"int32\")\n        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), \"int32\")\n        token_ids = T.match_buffer(var_token_ids, (num_token,), \"int32\")\n        token_cnt = T.match_buffer(var_token_cnt, (num_token,), \"int32\")\n        penalties = T.match_buffer(var_penalties, (num_seq, 3), \"float32\")\n\n        for token in T.serial(num_token):\n            with T.sblock(\"block\"):\n                vp = T.axis.spatial(num_token, token)\n                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (\n                    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]\n                )\n                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] 
= T.if_then_else(\n                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0),\n                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],\n                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],\n                )\n\n    return _apply_penalty_inplace\n\n\ndef _get_apply_penalty_inplace(target: tvm.target.Target):\n    tx = 1024  # default\n    max_num_threads_per_block = get_max_num_threads_per_block(target)\n    tx = min(tx, max_num_threads_per_block)\n    check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)\n\n    @T.prim_func\n    def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals\n        var_logits: T.handle,\n        var_seq_ids: T.handle,\n        var_pos2seq_id: T.handle,\n        var_token_ids: T.handle,\n        var_token_cnt: T.handle,\n        var_penalties: T.handle,\n    ) -> None:\n        \"\"\"Function that applies penalties in place.\"\"\"\n        T.func_attr(\n            {\n                \"global_symbol\": \"apply_penalty_inplace\",\n                \"tir.noalias\": True,\n                \"tir.is_scheduled\": True,\n            }\n        )\n        batch_size = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        num_token = T.int32(is_size_var=True)\n        num_seq = T.int32(is_size_var=True)\n        logits = T.match_buffer(var_logits, (batch_size, vocab_size), \"float32\")\n        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), \"int32\")\n        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), \"int32\")\n        token_ids = T.match_buffer(var_token_ids, (num_token,), \"int32\")\n        token_cnt = T.match_buffer(var_token_cnt, (num_token,), \"int32\")\n        penalties = T.match_buffer(var_penalties, (num_seq, 3), \"float32\")\n\n        for p0 in T.thread_binding(0, (num_token + tx - 1) // tx, \"blockIdx.x\"):\n            for p1 in 
T.thread_binding(0, tx, \"threadIdx.x\"):\n                with T.sblock(\"block\"):\n                    vp = T.axis.spatial(num_token, p0 * tx + p1)\n                    T.where(p0 * tx + p1 < num_token)\n                    # Penalties: (presence_penalty, frequency_penalty, repetition_penalty)\n                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (\n                        penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]\n                    )\n                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(\n                        logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < T.float32(0),\n                        logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]\n                        * penalties[pos2seq_id[vp], 2],\n                        logits[seq_ids[pos2seq_id[vp]], token_ids[vp]]\n                        / penalties[pos2seq_id[vp], 2],\n                    )\n\n    return _apply_penalty_inplace\n\n\ndef _get_apply_bitmask_inplace_cpu():\n    @T.prim_func\n    def _apply_bitmask_inplace(\n        var_logits: T.handle,\n        var_seq_ids: T.handle,\n        var_bitmask: T.handle,\n    ) -> None:\n        \"\"\"Function that applies vocabulary masking in place.\"\"\"\n        T.func_attr(\n            {\n                \"global_symbol\": \"apply_bitmask_inplace\",\n                \"tir.noalias\": True,\n                \"tir.is_scheduled\": True,\n            }\n        )\n        batch_size = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        num_seq = T.int32(is_size_var=True)\n        logits = T.match_buffer(var_logits, (batch_size, vocab_size), \"float32\")\n        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), \"int32\")\n        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), \"int32\")\n\n        for token in T.serial(num_seq * vocab_size):\n            with T.sblock(\"block\"):\n                vs = 
T.axis.spatial(num_seq, (token) // vocab_size)\n                vv = T.axis.spatial(vocab_size, (token) % vocab_size)\n\n                logits[seq_ids[vs], vv] = T.if_then_else(\n                    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,\n                    logits[seq_ids[vs], vv],\n                    T.min_value(\"float32\"),\n                )\n\n    return _apply_bitmask_inplace\n\n\ndef _get_apply_bitmask_inplace(target: tvm.target.Target):\n    tx = 1024  # default\n    max_num_threads_per_block = get_max_num_threads_per_block(target)\n    tx = min(tx, max_num_threads_per_block)\n    check_thread_limits(target, bdx=tx, bdy=1, bdz=1, gdz=1)\n\n    @T.prim_func\n    def _apply_bitmask_inplace(\n        var_logits: T.handle,\n        var_seq_ids: T.handle,\n        var_bitmask: T.handle,\n    ) -> None:\n        \"\"\"Function that applies vocabulary masking in place.\"\"\"\n        T.func_attr(\n            {\n                \"global_symbol\": \"apply_bitmask_inplace\",\n                \"tir.noalias\": True,\n                \"tir.is_scheduled\": True,\n            }\n        )\n        batch_size = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        num_seq = T.int32(is_size_var=True)\n        logits = T.match_buffer(var_logits, (batch_size, vocab_size), \"float32\")\n        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), \"int32\")\n        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), \"int32\")\n\n        for fused_s_v_0 in T.thread_binding(0, (num_seq * vocab_size + tx - 1) // tx, \"blockIdx.x\"):\n            for fused_s_v_1 in T.thread_binding(0, tx, \"threadIdx.x\"):\n                with T.sblock(\"block\"):\n                    vs = T.axis.spatial(num_seq, (fused_s_v_0 * tx + fused_s_v_1) // vocab_size)\n                    vv = T.axis.spatial(vocab_size, (fused_s_v_0 * tx + fused_s_v_1) % vocab_size)\n                    T.where(fused_s_v_0 * tx + fused_s_v_1 < 
num_seq * vocab_size)\n                    logits[seq_ids[vs], vv] = T.if_then_else(\n                        (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,\n                        logits[seq_ids[vs], vv],\n                        T.min_value(\"float32\"),\n                    )\n\n    return _apply_bitmask_inplace\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_sampler.py",
    "content": "\"\"\"The pass that attaches GPU sampler functions to the IRModule.\"\"\"\n\nfrom typing import Dict\n\nimport tvm\nfrom tvm import IRModule, relax, te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.script import tir as T\n\nfrom mlc_llm.op.batch_spec_verify import batch_spec_verify\nfrom mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachGPUSamplingFunc\")\nclass AttachGPUSamplingFunc:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach GPU sampling functions to IRModule.\"\"\"\n\n    def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]):\n        # Specifically for RWKV workloads, which contains -1 max_seq_len\n        max_batch_size = variable_bounds[\"batch_size\"]\n        self.variable_bounds = {\n            \"batch_size\": max_batch_size,\n            \"num_samples\": max_batch_size,\n            \"num_positions\": 6 * max_batch_size,\n        }\n        self.non_negative_var = [\"vocab_size\"]\n        self.target = target\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        if str(self.target.kind) not in [\"cuda\", \"vulkan\", \"metal\", \"webgpu\"]:\n            # Only enable GPU sampling for CUDA, Vulkan, Metal, and WebGPU.\n            return mod\n\n        bb = relax.BlockBuilder(mod)\n        if str(self.target.kind) == \"webgpu\":\n            # Only attach functions that do not contain i8s for WebGPU\n            gv_names = [\n                gv.name_hint\n                for gv in [\n                    _attach_argsort_func(bb),\n                    _attach_sample_with_top_p(bb),\n                ]\n            ]\n        else:\n            gv_names = [\n                gv.name_hint\n                for gv in [\n                    _attach_multinomial_sampling_func(bb),\n                    _attach_argsort_func(bb),\n                    
_attach_sample_with_top_p(bb),\n                    _attach_take_probs_func(bb),\n                    _attach_batch_verifier(bb),\n                    _attach_renormalize_by_top_p(bb, self.target),\n                ]\n            ]\n\n        mod = bb.finalize()\n        for gv_name in gv_names:\n            mod[gv_name] = (\n                mod[gv_name]\n                .with_attr(\"tir_var_upper_bound\", self.variable_bounds)\n                .with_attr(\"tir_non_negative_var\", self.non_negative_var)\n            )\n        return mod\n\n\ndef _attach_multinomial_sampling_func(bb: relax.BlockBuilder):\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    num_samples = tir.SizeVar(\"num_samples\", \"int64\")\n    vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n    probs = relax.Var(\"probs\", relax.TensorStructInfo((batch_size, vocab_size), \"float32\"))\n    uniform_samples = relax.Var(\n        \"uniform_samples\", relax.TensorStructInfo((num_samples,), \"float32\")\n    )\n    sample_indices = relax.Var(\"sample_indices\", relax.TensorStructInfo((num_samples,), \"int32\"))\n    with bb.function(\"multinomial_from_uniform\", [probs, uniform_samples, sample_indices]):\n        with bb.dataflow():\n            sample_shape = relax.ShapeExpr([num_samples, 1])\n            probs_tensor = nn.wrap_nested(probs, name=\"probs\")\n            uniform_samples_tensor = nn.wrap_nested(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    uniform_samples,\n                    sample_shape,\n                    sinfo_args=relax.TensorStructInfo(sample_shape, \"float32\"),\n                ),\n                name=\"uniform_samples\",\n            )\n            sample_indices_tensor = nn.wrap_nested(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    sample_indices,\n                    sample_shape,\n                    
sinfo_args=relax.TensorStructInfo(sample_shape, \"int32\"),\n                ),\n                name=\"sample_indices\",\n            )\n            result_tensor = nn.multinomial_from_uniform(  # pylint:disable=too-many-function-args\n                probs_tensor,\n                uniform_samples_tensor,\n                sample_indices_tensor,\n                \"int32\",\n                name=\"nn_multinomial_from_uniform\",\n            )\n            result = bb.emit(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    result_tensor._expr,  # pylint: disable=protected-access\n                    sample_indices.struct_info.shape,  # pylint: disable=no-member\n                    sinfo_args=sample_indices.struct_info,  # pylint: disable=no-member\n                )\n            )\n            output = bb.emit_output(result)\n        gv = bb.emit_func_output(output)\n    return gv\n\n\ndef _attach_argsort_func(bb: relax.BlockBuilder):\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n    probs = relax.Var(\"probs\", relax.TensorStructInfo((batch_size, vocab_size), \"float32\"))\n    with bb.function(\"argsort_probs\", [probs]):\n        with bb.dataflow():\n            sorted_indices = bb.emit(relax.op.argsort(probs, descending=True, dtype=\"int32\"))\n            sorted_values = bb.emit_te(\n                lambda unsorted_probs, sorted_indices: te.compute(\n                    (batch_size, vocab_size),\n                    lambda i, j: unsorted_probs[i, sorted_indices[i, j]],\n                    name=\"take_sorted_probs\",\n                ),\n                probs,\n                sorted_indices,\n                primfunc_name_hint=\"take_sorted_probs\",\n            )\n            output = bb.emit_output((sorted_values, sorted_indices))\n        gv = bb.emit_func_output(output)\n    return gv\n\n\n@T.prim_func\ndef full(var_result: 
T.handle, value: T.int32):\n    \"\"\"The filling function for top k.\"\"\"\n    batch_size = T.int32(is_size_var=True)\n    result = T.match_buffer(var_result, (batch_size, 1), \"int32\")\n    for i in T.serial(batch_size):\n        with T.sblock(\"block\"):\n            vi = T.axis.spatial(batch_size, i)\n            result[vi, 0] = value\n\n\ndef _attach_sample_with_top_p(bb: relax.BlockBuilder):  # pylint: disable=too-many-locals\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    num_samples = tir.SizeVar(\"num_samples\", \"int64\")\n    vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n    sorted_probs = relax.Var(\n        \"sorted_probs\", relax.TensorStructInfo((batch_size, vocab_size), \"float32\")\n    )\n    sorted_indices = relax.Var(\n        \"sorted_indices\", relax.TensorStructInfo((batch_size, vocab_size), \"int32\")\n    )\n    uniform_samples = relax.Var(\n        \"uniform_samples\", relax.TensorStructInfo((num_samples,), \"float32\")\n    )\n    sample_indices = relax.Var(\"sample_indices\", relax.TensorStructInfo((num_samples,), \"int32\"))\n    top_p = relax.Var(\"top_p\", relax.TensorStructInfo((batch_size,), \"float32\"))\n\n    with bb.function(\n        \"sample_with_top_p\",\n        [sorted_probs, sorted_indices, uniform_samples, sample_indices, top_p],\n    ):\n        with bb.dataflow():\n            sample_shape = relax.ShapeExpr([num_samples, 1])\n            top_p_shape = relax.ShapeExpr([batch_size, 1])\n            sorted_probs_tensor = nn.wrap_nested(sorted_probs, name=\"sorted_probs\")\n            sorted_indices_tensor = nn.wrap_nested(sorted_indices, name=\"sorted_indices\")\n            uniform_samples_tensor = nn.wrap_nested(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    uniform_samples,\n                    sample_shape,\n                    sinfo_args=relax.TensorStructInfo(sample_shape, \"float32\"),\n                ),\n                
name=\"uniform_samples\",\n            )\n            sample_indices_tensor = nn.wrap_nested(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    sample_indices,\n                    sample_shape,\n                    sinfo_args=relax.TensorStructInfo(sample_shape, \"int32\"),\n                ),\n                name=\"sample_indices\",\n            )\n            top_p_tensor = nn.wrap_nested(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    top_p,\n                    top_p_shape,\n                    sinfo_args=relax.TensorStructInfo(top_p_shape, \"float32\"),\n                ),\n                name=\"top_p\",\n            )\n            top_k_tensor = nn.tensor_ir_op(\n                full,\n                name_hint=\"full\",\n                args=[vocab_size],\n                out=nn.Tensor.placeholder(\n                    [batch_size, 1],\n                    \"int32\",\n                ),\n            )\n\n            result_tensor = (\n                nn.sample_top_p_top_k_from_sorted_prob(  # pylint:disable=too-many-function-args\n                    sorted_probs_tensor,\n                    sorted_indices_tensor,\n                    top_p_tensor,\n                    top_k_tensor,\n                    uniform_samples_tensor,\n                    sample_indices_tensor,\n                )\n            )\n            result = bb.emit_output(\n                relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    result_tensor._expr,  # pylint: disable=protected-access\n                    sample_indices.struct_info.shape,  # pylint: disable=no-member\n                    sinfo_args=sample_indices.struct_info,  # pylint: disable=no-member\n                )\n            )\n        gv = bb.emit_func_output(result)\n    return gv\n\n\ndef _attach_renormalize_by_top_p(bb: relax.BlockBuilder, target: 
tvm.target.Target):\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n    num_pivots = 3\n    probs = relax.Var(\"probs\", relax.TensorStructInfo((batch_size, vocab_size), \"float32\"))\n    top_p = relax.Var(\"top_p\", relax.TensorStructInfo((batch_size,), \"float32\"))\n    init_pivots = relax.Var(\n        \"init_pivots\", relax.TensorStructInfo((batch_size, num_pivots), \"float32\")\n    )\n    with bb.function(\"renormalize_by_top_p\", [probs, top_p, init_pivots]):\n        with bb.dataflow():\n            cutoff_output = bb.emit(\n                relax.call_tir(\n                    bb.add_func(top_p_pivot(num_pivots, target), \"top_p_pivot_cutoff\"),\n                    args=[probs, top_p, init_pivots],\n                    out_sinfo=[top_p.struct_info, top_p.struct_info],  # pylint: disable=no-member\n                )\n            )\n            final_pivot = cutoff_output[0]\n            renorm_sum = cutoff_output[1]\n            renormalized_probs = bb.emit_output(\n                relax.call_tir(\n                    bb.add_func(top_p_renorm(target), \"top_p_renorm_after_cutoff\"),\n                    args=[probs, final_pivot, renorm_sum],\n                    out_sinfo=probs.struct_info,  # pylint: disable=no-member\n                )\n            )\n        gv = bb.emit_func_output(renormalized_probs)\n    return gv\n\n\ndef _attach_take_probs_func(bb: relax.BlockBuilder):\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    num_samples = tir.SizeVar(\"num_samples\", \"int64\")\n    num_positions = tir.SizeVar(\"num_positions\", \"int64\")\n    vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n    unsorted_probs = relax.Var(\n        \"unsorted_probs\", relax.TensorStructInfo((batch_size, vocab_size), \"float32\")\n    )\n    sorted_indices = relax.Var(\n        \"sorted_indices\", relax.TensorStructInfo((batch_size, vocab_size), \"int32\")\n    )\n    sample_indices = 
relax.Var(\"sample_indices\", relax.TensorStructInfo((num_samples,), \"int32\"))\n    sampling_results = relax.Var(\"sampling_result\", relax.TensorStructInfo((num_samples,), \"int32\"))\n    top_prob_offsets = relax.Var(\n        \"top_prob_offsets\", relax.TensorStructInfo((num_positions,), \"int32\")\n    )\n\n    @T.prim_func\n    def sampler_take_probs_tir(  # pylint: disable=too-many-locals,too-many-arguments\n        var_unsorted_probs: T.handle,\n        var_sorted_indices: T.handle,\n        var_sample_indices: T.handle,\n        var_sampling_results: T.handle,\n        var_top_prob_offsets: T.handle,\n        var_sampled_values: T.handle,\n        var_top_prob_probs: T.handle,\n        var_top_prob_indices: T.handle,\n    ):\n        batch_size = T.int32(is_size_var=True)\n        num_samples = T.int32(is_size_var=True)\n        num_positions = T.int32(is_size_var=True)\n        vocab_size = T.int32(is_size_var=True)\n        unsorted_probs = T.match_buffer(var_unsorted_probs, (batch_size, vocab_size), \"float32\")\n        sorted_indices = T.match_buffer(var_sorted_indices, (batch_size, vocab_size), \"int32\")\n        sample_indices = T.match_buffer(var_sample_indices, (num_samples,), \"int32\")\n        sampling_results = T.match_buffer(var_sampling_results, (num_samples,), \"int32\")\n        top_prob_offsets = T.match_buffer(var_top_prob_offsets, (num_positions,), \"int32\")\n        sampled_values = T.match_buffer(var_sampled_values, (num_samples,), \"float32\")\n        top_prob_probs = T.match_buffer(var_top_prob_probs, (num_positions,), \"float32\")\n        top_prob_indices = T.match_buffer(var_top_prob_indices, (num_positions,), \"int32\")\n        for i in T.serial(num_positions + num_samples):\n            with T.sblock(\"block\"):\n                vi = T.axis.spatial(num_positions + num_samples, i)\n                if vi < num_positions:\n                    row = T.floordiv(top_prob_offsets[vi], vocab_size)\n                    col = 
T.floormod(top_prob_offsets[vi], vocab_size)\n                    top_prob_indices[vi] = sorted_indices[row, col]\n                    top_prob_probs[vi] = unsorted_probs[row, sorted_indices[row, col]]\n                else:\n                    vj: T.int32 = vi - num_positions\n                    sampled_values[vj] = unsorted_probs[sample_indices[vj], sampling_results[vj]]\n\n    args = [\n        unsorted_probs,\n        sorted_indices,\n        sample_indices,\n        sampling_results,\n        top_prob_offsets,\n    ]\n    with bb.function(\"sampler_take_probs\", args):\n        with bb.dataflow():\n            taken_probs_indices = bb.emit_output(\n                relax.call_tir(\n                    bb.add_func(sampler_take_probs_tir, \"sampler_take_probs_tir\"),\n                    args,\n                    out_sinfo=[\n                        relax.TensorStructInfo((num_samples,), \"float32\"),\n                        relax.TensorStructInfo((num_positions,), \"float32\"),\n                        relax.TensorStructInfo((num_positions,), \"int32\"),\n                    ],\n                )\n            )\n        gv = bb.emit_func_output(taken_probs_indices)\n    return gv\n\n\ndef _attach_batch_verifier(bb: relax.BlockBuilder):\n    num_nodes = tir.SizeVar(\"num_nodes\", \"int64\")\n    nbatch = tir.SizeVar(\"nbatch\", \"int64\")\n    vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n    draft_probs = relax.Var(\n        \"draft_probs\", relax.TensorStructInfo((num_nodes, vocab_size), \"float32\")\n    )\n    draft_tokens = relax.Var(\"draft_tokens\", relax.TensorStructInfo((num_nodes,), \"int32\"))\n    model_probs = relax.Var(\n        \"model_probs\", relax.TensorStructInfo((num_nodes, vocab_size), \"float32\")\n    )\n    token_tree_first_child = relax.Var(\n        \"token_tree_first_child\", relax.TensorStructInfo((num_nodes,), \"int32\")\n    )\n    token_tree_next_sibling = relax.Var(\n        \"token_tree_next_sibling\", 
relax.TensorStructInfo((num_nodes,), \"int32\")\n    )\n    uniform_samples = relax.Var(\"uniform_samples\", relax.TensorStructInfo((num_nodes,), \"float32\"))\n    token_tree_parent_ptr = relax.Var(\n        \"token_tree_parent_ptr\", relax.TensorStructInfo((nbatch,), \"int32\")\n    )\n    args = [\n        draft_probs,\n        draft_tokens,\n        model_probs,\n        token_tree_first_child,\n        token_tree_next_sibling,\n        uniform_samples,\n        token_tree_parent_ptr,\n    ]\n    with bb.function(\"sampler_verify_draft_tokens\", args):\n        with bb.dataflow():\n            res = bb.emit_output(\n                relax.call_tir_inplace(\n                    bb.add_func(\n                        batch_spec_verify(vocab_size),\n                        \"batch_verify_on_gpu_single_kernel\",\n                    ),\n                    args,\n                    inplace_indices=[\n                        args.index(model_probs),\n                        args.index(token_tree_parent_ptr),\n                    ],\n                    out_sinfo=[\n                        model_probs.struct_info,  # pylint: disable=no-member\n                        token_tree_parent_ptr.struct_info,  # pylint: disable=no-member\n                    ],\n                )\n            )\n        gv = bb.emit_func_output(res)\n    return gv\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py",
    "content": "\"\"\"A compiler pass that attaches two-stage softmax with temperature.\"\"\"\n\nfrom typing import Any, Dict, Optional\n\nimport tvm\nfrom tvm import relax, tir\nfrom tvm.ir.module import IRModule\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\nfrom tvm.script import tir as T\n\nfrom ..support.max_thread_check import get_max_num_threads_per_block\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachSoftmaxWithTemperature\")\nclass AttachSoftmaxWithTemperature:  # pylint: disable=too-few-public-methods\n    \"\"\"Rewrites one-shot softmax into two-stage softmax.\"\"\"\n\n    def __init__(\n        self, target: tvm.target.Target, metadata: Optional[Dict[str, Any]] = None\n    ) -> None:\n        self.target = target\n        self.metadata = metadata\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        return _Rewriter(mod, self.target, self.metadata).transform()\n\n\n@mutator\nclass _Rewriter(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(\n        self,\n        mod: IRModule,\n        target: tvm.target.Target,\n        metadata: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        super().__init__(mod)\n        self.mod = mod\n        self.target = target\n        self.metadata = metadata\n        self.chunk_size = 4096\n        self.active_vocab_size = self.metadata.get(\"active_vocab_size\") if self.metadata else None\n\n    def transform(self) -> IRModule:\n        \"\"\"Entry point\"\"\"\n        batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n        vocab_size = tir.SizeVar(\"vocab_size\", \"int64\")\n        dtype = \"float32\"\n        logits = relax.Var(\"logits\", relax.TensorStructInfo([batch_size, 1, vocab_size], dtype))\n        temperature = relax.Var(\"temperature\", relax.TensorStructInfo([batch_size], dtype))\n        with self.builder_.function(\"softmax_with_temperature\", 
params=[logits, temperature]):\n            with self.builder_.dataflow():\n                output_struct_info = logits.struct_info  # pylint: disable=no-member\n                new_shape = relax.ShapeExpr([batch_size, vocab_size])\n                logits = relax.call_pure_packed(\n                    \"vm.builtin.reshape\",\n                    logits,\n                    new_shape,\n                    sinfo_args=relax.TensorStructInfo(new_shape, dtype),\n                )\n                f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(\n                    self.target, self.chunk_size, self.active_vocab_size\n                )\n                chunked_result_struct_info = relax.TensorStructInfo(\n                    (batch_size, (vocab_size + self.chunk_size - 1) // self.chunk_size),\n                    \"float32\",\n                )\n                chunked_results = self.builder_.emit(\n                    relax.call_tir(\n                        self.builder_.add_func(f_chunk_lse, \"chunk_lse\"),\n                        args=[logits, temperature],\n                        out_sinfo=[\n                            chunked_result_struct_info,\n                            chunked_result_struct_info,\n                        ],\n                    )\n                )\n                chunked_sum = chunked_results[0]\n                chunked_max = chunked_results[1]\n                softmax = self.builder_.emit(\n                    relax.call_tir(\n                        self.builder_.add_func(f_softmax_with_lse, \"softmax_with_chunked_sum\"),\n                        args=[logits, temperature, chunked_sum, chunked_max],\n                        out_sinfo=logits.struct_info,\n                    )\n                )\n                softmax = self.builder_.emit_output(\n                    relax.call_pure_packed(\n                        \"vm.builtin.reshape\",\n                        softmax,\n                        
output_struct_info.shape,\n                        sinfo_args=output_struct_info,\n                    )\n                )\n            self.builder_.emit_func_output(softmax)\n        return self.builder_.get()\n\n\ndef _get_lse_and_softmax_func(  # pylint: disable=too-many-locals,too-many-statements\n    target: tvm.target.Target, chunk_size: int, active_vocab_size: int\n):\n    # NOTE: A quick note on the softmax implementation.\n    # We once tried to multiply every element by log2e which can be computed\n    # potentially more efficiently on hardware.\n    # However, when the input values are large, multiplying by the factor of log2e\n    # causes numerical issue in float32 dtype.\n    # This leads to the softmax output not summing up to 1.\n    # For numerical stability, we removed the log2e factor and switched back\n    # to the standard log/exp computation.\n\n    # The kernels below handle both the cases of temperature=0 and temperature != 0.\n    # - When temperature is not 0, the first kernel computes the log-sum-exp of\n    # chunks (subtracted by the max value in chunk), and the max values of chunks.\n    # The second kernel merges the log-sum-exp with the maximum values.\n    # - When temperature is 0, the first kernel computes the max value and the counts\n    # of the max value. 
The second kernel merges the max and counts, and sets the\n    # softmax of the maximum values to \"1 / max_count\".\n\n    # pylint: disable=invalid-name\n    @T.prim_func\n    def chunk_lse(  # pylint: disable=too-many-locals\n        var_A: T.handle,\n        var_temperature: T.handle,\n        var_chunked_sum: T.handle,\n        var_chunked_max: T.handle,\n    ):\n        T.func_attr({\"tir.noalias\": T.bool(True)})\n        batch_size = T.int64(is_size_var=True)\n        vocab_size = T.int64(is_size_var=True)\n        num_chunks = T.int64(is_size_var=True)\n        A = T.match_buffer(var_A, (batch_size, vocab_size), dtype=\"float32\")\n        temperature = T.match_buffer(var_temperature, (batch_size,), dtype=\"float32\")\n        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks), dtype=\"float32\")\n        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks), dtype=\"float32\")\n        A_pad = T.sblock_alloc_buffer(\n            (batch_size, num_chunks, T.int64(chunk_size)), dtype=\"float32\"\n        )\n        temp_max = T.sblock_alloc_buffer((batch_size, num_chunks), dtype=\"float32\")\n        temp_sum = T.sblock_alloc_buffer((batch_size, num_chunks), dtype=\"float32\")\n\n        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):\n            with T.sblock(\"pad\"):\n                v0, v1, v2 = T.axis.remap(\"SSS\", [l0, l1, l2])\n                A_pad[v0, v1, v2] = T.Select(\n                    v1 * T.int64(chunk_size) + v2\n                    < (active_vocab_size if active_vocab_size is not None else vocab_size),\n                    T.if_then_else(\n                        temperature[v0] > T.float32(1e-5),\n                        A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0],\n                        A[v0, v1 * T.int64(chunk_size) + v2],\n                    ),\n                    T.min_value(\"float32\"),\n                )\n        for l0, l1, l2 in T.grid(batch_size, 
num_chunks, T.int64(chunk_size)):\n            with T.sblock(\"max\"):\n                v0, v1, v2 = T.axis.remap(\"SSR\", [l0, l1, l2])\n                with T.init():\n                    temp_max[v0, v1] = T.min_value(\"float32\")\n                temp_max[v0, v1] = T.max(temp_max[v0, v1], A_pad[v0, v1, v2])\n        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):\n            with T.sblock(\"sum_exp\"):\n                v0, v1, v2 = T.axis.remap(\"SSR\", [l0, l1, l2])\n                with T.init():\n                    temp_sum[v0, v1] = T.float32(0)\n                temp_sum[v0, v1] += T.if_then_else(\n                    v1 * T.int64(chunk_size) + v2\n                    < (active_vocab_size if active_vocab_size is not None else vocab_size),\n                    T.Select(\n                        temperature[v0] > T.float32(1e-5),\n                        T.exp(A_pad[v0, v1, v2] - temp_max[v0, v1]),\n                        T.cast(A_pad[v0, v1, v2] == temp_max[v0, v1], \"float32\"),\n                    ),\n                    T.float32(0),\n                )\n        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(1)):\n            with T.sblock(\"log\"):\n                v0, v1, v2 = T.axis.remap(\"SSS\", [l0, l1, l2])\n                chunked_sum[v0, v1] = T.Select(\n                    temperature[v0] > T.float32(1e-5),\n                    T.log(temp_sum[v0, v1]),\n                    temp_sum[v0, v1],\n                )\n                chunked_max[v0, v1] = temp_max[v0, v1]\n\n    @T.prim_func\n    def softmax_with_chunked_sum(\n        var_A: T.handle,\n        var_temperature: T.handle,\n        var_chunked_sum: T.handle,\n        var_chunked_max: T.handle,\n        var_softmax: T.handle,\n    ):\n        T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n        batch_size = T.int64(is_size_var=True)\n        vocab_size = T.int64(is_size_var=True)\n        num_chunks = T.int64(is_size_var=True)\n  
      A = T.match_buffer(var_A, (batch_size, vocab_size), dtype=\"float32\")\n        temperature = T.match_buffer(var_temperature, (batch_size,), dtype=\"float32\")\n        chunked_sum = T.match_buffer(var_chunked_sum, (batch_size, num_chunks), dtype=\"float32\")\n        chunked_max = T.match_buffer(var_chunked_max, (batch_size, num_chunks), dtype=\"float32\")\n        softmax = T.match_buffer(var_softmax, (batch_size, vocab_size), dtype=\"float32\")\n        temp_max = T.sblock_alloc_buffer((batch_size,), dtype=\"float32\")\n        temp_sum = T.sblock_alloc_buffer((batch_size,), dtype=\"float32\")\n        for l0, l1 in T.grid(batch_size, num_chunks):\n            with T.sblock(\"max\"):\n                v0, v1 = T.axis.remap(\"SR\", [l0, l1])\n                with T.init():\n                    temp_max[v0] = T.min_value(\"float32\")\n                temp_max[v0] = T.max(temp_max[v0], chunked_max[v0, v1])\n        for l0, l1 in T.grid(batch_size, num_chunks):\n            with T.sblock(\"sum_exp\"):\n                v0, v1 = T.axis.remap(\"SR\", [l0, l1])\n                with T.init():\n                    temp_sum[v0] = T.float32(0)\n                temp_sum[v0] += T.Select(\n                    temperature[v0] > T.float32(1e-5),\n                    T.exp(chunked_sum[v0, v1] + chunked_max[v0, v1] - temp_max[v0]),\n                    T.cast(chunked_max[v0, v1] == temp_max[v0], \"float32\") * chunked_sum[v0, v1],\n                )\n        for l0, l1, l2 in T.grid(batch_size, num_chunks, T.int64(chunk_size)):\n            with T.sblock(\"log_pad\"):\n                v0, v1, v2 = T.axis.remap(\"SSS\", [l0, l1, l2])\n                if v1 * T.int64(chunk_size) + v2 < vocab_size:\n                    softmax[v0, v1 * T.int64(chunk_size) + v2] = T.Select(\n                        v1 * T.int64(chunk_size) + v2\n                        < (active_vocab_size if active_vocab_size is not None else vocab_size),\n                        T.if_then_else(\n               
             temperature[v0] > T.float32(1e-5),\n                            T.exp(\n                                A[v0, v1 * T.int64(chunk_size) + v2] / temperature[v0]\n                                - (T.log(temp_sum[v0]) + temp_max[v0])\n                            ),\n                            T.cast(\n                                A[v0, v1 * T.int64(chunk_size) + v2] == temp_max[v0],\n                                \"float32\",\n                            )\n                            / temp_sum[v0],\n                        ),\n                        T.float32(0),\n                    )\n\n    sch = tvm.s_tir.Schedule(IRModule({\"softmax_with_chunked_sum\": softmax_with_chunked_sum}))\n\n    def apply_gpu_schedule(target, sch):\n        max_threads = get_max_num_threads_per_block(target)\n        TX = 32\n        TY = max_threads // TX\n        unroll_depth = 64\n        # pylint: enable=invalid-name\n\n        sch.work_on(\"softmax_with_chunked_sum\")\n        l0, l1, l2 = sch.get_loops(\"log_pad\")\n        bx = sch.fuse(l0, l1)\n        sch.bind(bx, \"blockIdx.x\")\n        unroll, ty, tx = sch.split(l2, [None, TY, TX])\n        sch.bind(ty, \"threadIdx.y\")\n        sch.bind(tx, \"threadIdx.x\")\n        sch.annotate(unroll, ann_key=\"pragma_auto_unroll_max_step\", ann_val=unroll_depth)\n        sch.annotate(unroll, ann_key=\"pragma_unroll_explicit\", ann_val=1)\n\n        for block_name in [\"sum_exp\", \"max\"]:\n            block = sch.get_sblock(block_name)\n            sch.set_scope(block, buffer_index=0, storage_scope=\"shared\")\n            sch.compute_at(block, bx)\n            r_loop = sch.get_loops(block)[-1]\n            r_loop, tx = sch.split(r_loop, [None, TX])\n            sch.reorder(tx, r_loop)\n            sch.bind(tx, \"threadIdx.x\")\n            sch.annotate(r_loop, ann_key=\"pragma_auto_unroll_max_step\", ann_val=unroll_depth)\n            sch.annotate(r_loop, ann_key=\"pragma_unroll_explicit\", ann_val=1)\n\n        
return chunk_lse, sch.mod[\"softmax_with_chunked_sum\"]\n\n    if target.kind.name == \"llvm\":\n        return chunk_lse, sch.mod[\"softmax_with_chunked_sum\"]\n    return apply_gpu_schedule(target, sch)\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_spec_decode_aux_funcs.py",
    "content": "\"\"\"The pass that attaches auxiliary functions for speculative decoding to the IRModule.\"\"\"\n\nimport tvm\nfrom tvm import IRModule, relax, tir\nfrom tvm.relax import BlockBuilder, TensorStructInfo\nfrom tvm.script import tir as T\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachSpecDecodeAuxFuncs\")\nclass AttachSpecDecodeAuxFuncs:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach the scatter/gather TIR functions used by speculative decoding to the IRModule.\"\"\"\n\n    tensor_parallel_shards: int\n\n    def __init__(self, tensor_parallel_shards: int):\n        self.tensor_parallel_shards = tensor_parallel_shards\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        mod = mod.clone()\n        bb = BlockBuilder(mod)\n        bb.add_func(\n            _get_scatter_2d_inplace(dtype=\"float32\", global_symbol=\"scatter_probs\"),\n            \"scatter_probs\",\n        )\n        bb.add_func(\n            _get_gather_2d_inplace(dtype=\"float32\", global_symbol=\"gather_probs\"),\n            \"gather_probs\",\n        )\n        if \"prefill_to_last_hidden_states\" in mod:\n            hidden_states_struct_info = mod[\"prefill_to_last_hidden_states\"].ret_struct_info.fields[\n                0\n            ]  # pylint: disable=no-member\n            dtype = hidden_states_struct_info.dtype\n            _add_gather_hidden_states(bb, self.tensor_parallel_shards, dtype)\n            _add_scatter_hidden_states(bb, self.tensor_parallel_shards, dtype)\n        return bb.finalize()\n\n\ndef _get_scatter_2d_inplace(dtype: str, global_symbol: str):\n    @T.prim_func\n    def _scatter_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):\n        T.func_attr({\"global_symbol\": global_symbol, \"tir.noalias\": True})\n        batch_size = T.int32(is_size_var=True)\n        m = T.int32(is_size_var=True)\n        n = T.int32(is_size_var=True)\n        src = T.match_buffer(var_src, (batch_size, n), 
dtype)\n        indices = T.match_buffer(var_indices, (batch_size,), \"int32\")\n        dst = T.match_buffer(var_dst, (m, n), dtype)\n        for b, j in T.grid(batch_size, n):\n            with T.sblock(\"scatter_2d\"):\n                vb, vj = T.axis.remap(\"SS\", [b, j])\n                dst[indices[vb], vj] = src[vb, vj]\n\n    return _scatter_2d\n\n\ndef _get_gather_2d_inplace(dtype: str, global_symbol: str):\n    @T.prim_func\n    def _gather_2d(var_src: T.handle, var_indices: T.handle, var_dst: T.handle):\n        T.func_attr({\"global_symbol\": global_symbol, \"tir.noalias\": True})\n        batch_size = T.int32(is_size_var=True)\n        m = T.int32(is_size_var=True)\n        n = T.int32(is_size_var=True)\n        src = T.match_buffer(var_src, (m, n), dtype)\n        indices = T.match_buffer(var_indices, (batch_size,), \"int32\")\n        dst = T.match_buffer(var_dst, (batch_size, n), dtype)\n        for b, j in T.grid(batch_size, n):\n            with T.sblock(\"gather_2d\"):\n                vb, vj = T.axis.remap(\"SS\", [b, j])\n                dst[vb, vj] = src[indices[vb], vj]\n\n    return _gather_2d\n\n\ndef _add_scatter_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str):\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    m = tir.SizeVar(\"m\", \"int64\")\n    n = tir.SizeVar(\"n\", \"int64\")\n    src = relax.Var(\"src\", struct_info=TensorStructInfo([batch_size, n], dtype))\n    indices = relax.Var(\"indices\", struct_info=TensorStructInfo([batch_size], \"int32\"))\n    dst = relax.Var(\"dst\", struct_info=TensorStructInfo([m, n], dtype))\n    with bb.function(\"scatter_hidden_states\", [src, indices, dst]):\n        with bb.dataflow():\n            if tensor_parallel_shards > 1:\n                indices = relax.op.ccl.broadcast_from_worker0(indices)\n            output = bb.emit_output(\n                relax.op.call_tir_inplace(\n                    bb.add_func(\n                        
_get_scatter_2d_inplace(dtype, \"_scatter_hidden_states\"),\n                        \"_scatter_hidden_states\",\n                    ),\n                    [src, indices, dst],\n                    2,\n                    dst.struct_info,  # pylint: disable=no-member\n                )\n            )\n        gv = bb.emit_func_output(output)\n    return gv\n\n\ndef _add_gather_hidden_states(bb: BlockBuilder, tensor_parallel_shards: int, dtype: str):\n    batch_size = tir.SizeVar(\"batch_size\", \"int64\")\n    m = tir.SizeVar(\"m\", \"int64\")\n    n = tir.SizeVar(\"n\", \"int64\")\n    src = relax.Var(\"src\", struct_info=TensorStructInfo([m, n], dtype))\n    indices = relax.Var(\"indices\", struct_info=TensorStructInfo([batch_size], \"int32\"))\n    dst = relax.Var(\"dst\", struct_info=TensorStructInfo([batch_size, n], dtype))\n    with bb.function(\"gather_hidden_states\", [src, indices, dst]):\n        with bb.dataflow():\n            if tensor_parallel_shards > 1:\n                indices = relax.op.ccl.broadcast_from_worker0(indices)\n            output = bb.emit_output(\n                relax.op.call_tir_inplace(\n                    bb.add_func(\n                        _get_gather_2d_inplace(dtype, \"_gather_hidden_states\"),\n                        \"_gather_hidden_states\",\n                    ),\n                    [src, indices, dst],\n                    2,\n                    dst.struct_info,  # pylint: disable=no-member\n                )\n            )\n        gv = bb.emit_func_output(output)\n    return gv\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/attach_support_info.py",
    "content": "\"\"\"A couple of passes that simply supportive information onto the IRModule.\"\"\"\n\nfrom math import lcm\nfrom typing import Any, Dict, List\n\nimport tvm\nfrom tvm import IRModule, relax, tir\nfrom tvm.ir import Op\nfrom tvm.relax.expr_functor import PyExprVisitor, visitor\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachVariableBounds\")\nclass AttachVariableBounds:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach variable bounds to each Relax function, which primarily helps with memory planning.\"\"\"\n\n    def __init__(self, variable_bounds: Dict[str, int]):\n        # Specifically for RWKV workloads, which contains -1 max_seq_len\n        self.variable_bounds = {k: v for k, v in variable_bounds.items() if v > 0}\n        self.non_negative_var = [\"vocab_size\"]\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        for g_var, func in mod.functions_items():\n            if isinstance(func, relax.Function):\n                mod[g_var] = func.with_attr(\"tir_var_upper_bound\", self.variable_bounds).with_attr(\n                    \"tir_non_negative_var\", self.non_negative_var\n                )\n        return mod\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachAdditionalPrimFuncs\")\nclass AttachAdditionalPrimFuncs:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach extra TIR PrimFuncs to the IRModule\"\"\"\n\n    def __init__(self, functions: Dict[str, tir.PrimFunc]):\n        self.functions = functions\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        for func_name, func in self.functions.items():\n            mod[func_name] = func.with_attr(\"global_symbol\", func_name)\n        return mod\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachMemoryPlanAttr\")\nclass AttachMemoryPlanAttr:  # pylint: disable=too-few-public-methods\n 
   \"\"\"Attach memory planning attribute for dynamic function output planning to Relax functions.\"\"\"\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        for g_var, func in mod.functions_items():\n            if isinstance(func, relax.Function):\n                mod[g_var] = func.with_attr(\"relax.memory_plan_dynamic_func_output\", True)\n        return mod\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachCUDAGraphCaptureHints\")\nclass AttachCUDAGraphSymbolicCaptureHints:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach CUDA graph capture hints to the IRModule\"\"\"\n\n    def __init__(self, hints: Dict[str, List[str]]):\n        self.hints = hints\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        for g_var, func in mod.functions_items():\n            func_name = g_var.name_hint\n            if isinstance(func, relax.Function):\n                if func_name in self.hints:\n                    mod[g_var] = func.with_attr(\n                        \"relax.rewrite_cuda_graph.capture_symbolic_vars\",\n                        self.hints[func_name],\n                    )\n\n        return mod\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachPipelineParallelStages\")\nclass AttachPipelineParallelStages:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach number of pipeline stages to relax functions.\"\"\"\n\n    def __init__(self, pipeline_parallel_shards: int):\n        self.pipeline_parallel_shards = pipeline_parallel_shards\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        for g_var, func in mod.functions_items():\n            func_name = g_var.name_hint\n            if not isinstance(func, relax.Function) or func_name not in [\n                \"prefill\",\n                
\"decode\",\n                \"prefill_to_last_hidden_states\",\n                \"decode_to_last_hidden_states\",\n                \"batch_prefill\",\n                \"batch_decode\",\n                \"batch_verify\",\n                \"batch_prefill_to_last_hidden_states\",\n                \"batch_decode_to_last_hidden_states\",\n                \"batch_verify_to_last_hidden_states\",\n            ]:\n                continue\n            mod[g_var] = func.with_attr(\"pipeline_parallel_stages\", self.pipeline_parallel_shards)\n\n        return mod\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachSequenceLengthPaddingFactor\")\nclass AttachSequenceLengthPaddingFactor:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach sequence length padding factor to the metadata\"\"\"\n\n    def __init__(self, target: tvm.target.Target, metadata: Dict[str, Any]):\n        self.target = target\n        self.metadata = metadata\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n\n        @visitor\n        class _Visitor(PyExprVisitor):  # pylint: disable=abstract-method\n            def __init__(self, target: tvm.target.Target) -> None:\n                self.padding_factor = 1\n                self.target = target\n                self._op_call_dps_packed = Op.get(\"relax.call_dps_packed\")\n\n            def run(self, mod: IRModule) -> int:\n                \"\"\"Entry point of the visitor.\"\"\"\n                # Right now we only need padding for CUDA SM100a architecture.\n                # When the target is SM100a and uses cutlass gemm function,\n                # the sequence length needs to be padded to multiple of 4.\n                if self.target.kind.name != \"cuda\" or self.target.attrs.get(\"arch\") != \"sm_100a\":\n                    return 1\n\n                for _, func in mod.functions_items():\n                    if isinstance(func, relax.Function):\n             
           self.visit_expr(func)\n                return self.padding_factor\n\n            def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-renamed\n                super().visit_call_(call)\n                if call.op != self._op_call_dps_packed:\n                    return\n                func_name = str(call.args[0].global_symbol)\n                if func_name in [\n                    \"cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn\",\n                    \"cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn\",\n                ]:\n                    # Take the least common multiple of the current padding factor and 4\n                    self.padding_factor = lcm(self.padding_factor, 4)\n\n        # self.metadata[\"sequence_length_padding\"] = True\n        padding_factor = _Visitor(self.target).run(mod)\n        if padding_factor > 1:\n            self.metadata[\"seqlen_padding_factor\"] = padding_factor\n        return mod\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/blas_dispatch.py",
    "content": "\"\"\"A compiler pass that dispatches patterns to cuBLAS/hipBLAS.\"\"\"\n\nimport tvm\nfrom tvm import IRModule, relax\nfrom tvm.relax.backend import get_patterns_with_prefix\n\ntry:\n    import tvm.relax.backend.cuda.cublas as _cublas\n    import tvm.relax.backend.rocm.hipblas as _hipblas\nexcept ImportError:\n    # Note: legacy path of cublas/hipblas for backward compatibility\n    import tvm.relax.backend.contrib.cublas as _cublas\n    import tvm.relax.backend.contrib.hipblas as _hipblas\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"BLASDispatch\")\nclass BLASDispatch:  # pylint: disable=too-few-public-methods,broad-exception-raised\n    \"\"\"A compiler pass that dispatches patterns to cuBLAS/hipBLAS.\"\"\"\n\n    def __init__(self, target: tvm.target.Target) -> None:\n        if target.kind.name == \"cuda\":\n            self.has_blas = tvm.get_global_func(\"relax.ext.cublas\", True)\n            if not self.has_blas:\n                raise Exception(\"cuBLAS is not enabled.\")\n            self.patterns = get_patterns_with_prefix(\"cublas\")\n        elif target.kind.name == \"rocm\":\n            self.has_blas = tvm.get_global_func(\"relax.ext.hipblas\", True)\n            if not self.has_blas:\n                raise Exception(\"hipBLAS is not enabled.\")\n            self.patterns = get_patterns_with_prefix(\"hipblas\")\n        else:\n            raise Exception(f\"Unsupported target {target.kind.name} for BLAS dispatch.\")\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        model_names = [\n            gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function)\n        ]\n        # Exclude non-batched decode functions\n        model_names = [name for name in model_names if \"batch\" in name or \"decode\" not in name]\n        mod = tvm.transform.Sequential(\n            [\n                
relax.transform.FuseOpsByPattern(\n                    self.patterns,\n                    bind_constants=False,\n                    annotate_codegen=True,\n                    entry_functions=model_names,\n                ),\n                relax.transform.RunCodegen({}, entry_functions=model_names),\n            ]\n        )(mod)\n        return mod\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/clean_up_tir_attrs.py",
    "content": "\"\"\"A compiler pass that cleans up undesired TIR attrs.\"\"\"\n\nfrom typing import List\n\nimport tvm\nfrom tvm.ir.module import IRModule\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"CleanUpTIRAttrs\")\nclass CleanUpTIRAttrs:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that cleans up undesired TIR attrs.\"\"\"\n\n    def __init__(self, attrs: List[str]):\n        self.attrs = attrs\n\n    def transform_module(\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        for g_var, func in mod.functions_items():\n            changed = False\n            # Remove every undesired attr present, not only the first match\n            for attr in self.attrs:\n                if func.attrs is not None and attr in func.attrs:\n                    func = func.without_attr(attr)\n                    changed = True\n            if changed:\n                mod[g_var] = func\n        return mod\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/dispatch_kv_cache_creation.py",
    "content": "\"\"\"A pass that rewrites KV cache creation functions in IRModule.\"\"\"\n\nimport json\nfrom typing import Any, Dict, List\n\nimport tvm\nfrom tvm import IRModule, relax\nfrom tvm.relax.frontend.nn.llm import kv_cache\nfrom tvm.relax.frontend.nn.llm.kv_cache import RopeMode\n\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\ndef extract_creation_args(func: relax.Function) -> Dict[str, Any]:\n    \"\"\"Extract the KV cache creation args from the given generic creation func.\"\"\"\n    assert isinstance(func.body, relax.SeqExpr)\n    assert len(func.body.blocks) == 1\n    assert isinstance(func.body.blocks[0], relax.DataflowBlock)\n    assert isinstance(func.body.blocks[0].bindings[0], relax.VarBinding)\n    assert isinstance(func.body.blocks[0].bindings[0].value, relax.Call)\n    assert func.body.blocks[0].bindings[0].value.op == tvm.ir.Op.get(\"relax.call_pure_packed\")\n    call_args = func.body.blocks[0].bindings[0].value.args\n    assert isinstance(call_args[0], relax.ExternFunc)\n    assert call_args[0].global_symbol == \"mlc.create_paged_kv_cache_generic\"\n    args = call_args[1:]\n    assert len(args) == 18\n    assert isinstance(args[0], (relax.StringImm, relax.Tuple))\n    # Check if attn_kind is a single value or a list with length of hidden layers\n    if isinstance(args[0], relax.StringImm):\n        assert args[0].value in [\"mha\", \"mla\"]\n        attn_kind = args[0].value\n    else:\n        assert len(args[0].fields) == args[3].value.value\n        for i, attention_type in enumerate(args[0].fields):\n            assert isinstance(attention_type, relax.StringImm)\n            assert attention_type.value in [\"mha\", \"mla\", \"mha_sliding\"]\n        attn_kind = [args[0].fields[i].value for i in range(len(args[0]))]\n    assert isinstance(args[1], relax.ShapeExpr)\n    assert len(args[1].values) == 5\n    assert isinstance(args[2], relax.ShapeExpr)\n    for i in range(3, 18):\n        if i in [13, 
14, 17]:\n            continue\n        assert isinstance(args[i], relax.PrimValue), f\"args[{i}] is {type(args[i])}\"\n        assert isinstance(args[i].value, (tvm.tir.IntImm, tvm.tir.FloatImm))\n    assert isinstance(args[13], relax.StringImm)\n    assert isinstance(args[16], (relax.Constant, relax.PrimValue))\n    assert isinstance(args[17], relax.DataTypeImm)\n\n    return {\n        \"attn_kind\": attn_kind,\n        \"max_batch_size\": args[1].values[0],\n        \"max_total_seq_len\": args[1].values[1],\n        \"prefill_chunk_size\": args[1].values[2],\n        \"page_size\": args[1].values[3],\n        \"support_sliding_window\": args[1].values[4],\n        \"layer_partition\": args[2],\n        \"num_hidden_layers\": args[3].value.value,\n        \"num_attention_heads\": args[4].value.value,\n        \"num_key_value_heads\": args[5].value.value,\n        \"qk_head_dim\": args[6].value.value,\n        \"v_head_dim\": args[7].value.value,\n        \"mla_original_qk_head_dim\": args[8].value.value,\n        \"mla_original_v_head_dim\": args[9].value.value,\n        \"rope_mode\": args[10].value.value,\n        \"rope_scale\": args[11].value.value,\n        \"rope_theta\": args[12].value.value,\n        \"rope_scaling\": json.loads(args[13].value),\n        \"rope_ext_factors\": args[14],\n        \"rotary_dim\": args[15].value.value,\n        \"enable_disaggregation\": bool(args[16].value.value),\n        \"dtype\": args[17].value,\n    }\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"DispatchKVCacheCreation\")\nclass DispatchKVCacheCreation:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Rewrite KV cache creation functions to IRModule.\"\"\"\n\n    def __init__(\n        self, target: tvm.target.Target, flashinfer: bool, metadata: Dict[str, Any]\n    ) -> None:\n        \"\"\"Initializer.\n\n        Parameters\n        ----------\n        target : tvm.target.Target\n            The target of the model compilation.\n\n        flashinfer : 
bool\n            A boolean indicating if flashinfer is enabled.\n\n        metadata : Dict[str, Any]\n            The model's metadata for KV cache creation.\n            Note that the metadata will be updated in this pass -- the\n            KV cache metadata will be attached.\n        \"\"\"\n        self.target = target\n        self.flashinfer = flashinfer\n        self.metadata = metadata\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n        func_dict = {}\n        creation_func = None\n        for g_var, func in mod.functions_items():\n            # Try to find the `create_paged_kv_cache` func.\n            if g_var.name_hint == \"create_paged_kv_cache\":\n                creation_func = func\n            else:\n                func_dict[g_var] = func\n\n        if creation_func is None:\n            return mod\n\n        new_mod = IRModule(func_dict)\n        if mod.attrs is not None:\n            new_mod = new_mod.with_attrs(mod.attrs)\n\n        kwargs = extract_creation_args(creation_func)\n        self.attach_kv_cache_metadata(kwargs)\n\n        bb = relax.BlockBuilder(new_mod)\n        extern_mods = []\n        extern_mods += self.create_tir_paged_kv_cache(bb, kwargs)\n        extern_mods += self.create_flashinfer_paged_kv_cache(bb, kwargs)\n\n        mod = bb.finalize()\n        mod_attrs = dict(mod.attrs) if mod.attrs else {}\n        mod = mod.with_attr(\"external_mods\", mod_attrs.get(\"external_mods\", []) + extern_mods)\n        return mod\n\n    def attach_kv_cache_metadata(self, kwargs: Dict[str, Any]):\n        \"\"\"Attach the KV cache metadata to model metadata.\"\"\"\n        self.metadata[\"kv_cache\"] = {\n            \"num_hidden_layers\": kwargs[\"num_hidden_layers\"],\n            \"num_attention_heads\": kwargs[\"num_attention_heads\"],\n            \"num_key_value_heads\": kwargs[\"num_key_value_heads\"],\n            \"head_dim\": 
kwargs[\"qk_head_dim\"],\n        }\n\n    def create_tir_paged_kv_cache(\n        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]\n    ) -> List[tvm.runtime.Module]:\n        \"\"\"Create the TIR-based PagedKVCache\"\"\"\n        max_batch_size = relax.Var(\n            \"max_batch_size_\", relax.ShapeStructInfo([kwargs[\"max_batch_size\"]])\n        )\n        max_total_seq_len = relax.Var(\n            \"max_total_seq_len_\", relax.ShapeStructInfo([kwargs[\"max_total_seq_len\"]])\n        )\n        prefill_chunk_size = relax.Var(\n            \"prefill_chunk_size_\", relax.ShapeStructInfo([kwargs[\"prefill_chunk_size\"]])\n        )\n        page_size = relax.Var(\"page_size_\", relax.ShapeStructInfo([kwargs[\"page_size\"]]))\n        support_sliding_window = relax.Var(\n            \"support_sliding_window_\",\n            relax.ShapeStructInfo([kwargs[\"support_sliding_window\"]]),\n        )\n\n        # Ensure 'enable_disaggregation' is optional\n        enable_disaggregation = kwargs.pop(\"enable_disaggregation\", False)\n        kwargs[\"enable_disaggregation\"] = enable_disaggregation\n\n        with bb.function(\n            name=\"create_tir_paged_kv_cache\",\n            params=[\n                max_batch_size,\n                max_total_seq_len,\n                prefill_chunk_size,\n                page_size,\n                support_sliding_window,\n            ],\n        ):\n            cache = kv_cache.TIRPagedKVCache(target=self.target, **kwargs)\n            bb.emit_func_output(cache._expr)  # pylint: disable=protected-access\n\n        return cache.extern_mods\n\n    def create_flashinfer_paged_kv_cache(\n        self, bb: relax.BlockBuilder, kwargs: Dict[str, Any]\n    ) -> List[tvm.runtime.Module]:\n        \"\"\"Create the FlashInfer-based PagedKVCache\"\"\"\n        # Filter the cases which FlashInfer does not support.\n        if (  # pylint: disable=too-many-boolean-expressions\n            not self.flashinfer\n            or 
self.target.kind.name != \"cuda\"\n            or str(kwargs[\"dtype\"]) not in [\"float16\", \"bfloat16\"]\n            or (\n                kwargs[\"rope_mode\"] == RopeMode.INLINE\n                and (\n                    kwargs[\"rotary_dim\"] != kwargs[\"qk_head_dim\"]\n                    or kwargs[\"qk_head_dim\"] != kwargs[\"v_head_dim\"]\n                )\n            )\n        ):\n            return []\n\n        max_batch_size = relax.Var(\n            \"max_batch_size_\", relax.ShapeStructInfo([kwargs[\"max_batch_size\"]])\n        )\n        max_total_seq_len = relax.Var(\n            \"max_total_seq_len_\", relax.ShapeStructInfo([kwargs[\"max_total_seq_len\"]])\n        )\n        prefill_chunk_size = relax.Var(\n            \"prefill_chunk_size_\", relax.ShapeStructInfo([kwargs[\"prefill_chunk_size\"]])\n        )\n        page_size = relax.Var(\"page_size_\", relax.ShapeStructInfo([kwargs[\"page_size\"]]))\n        support_sliding_window = relax.Var(\n            \"support_sliding_window_\",\n            relax.ShapeStructInfo([kwargs[\"support_sliding_window\"]]),\n        )\n\n        try:\n            with bb.function(\n                name=\"create_flashinfer_paged_kv_cache\",\n                params=[\n                    max_batch_size,\n                    max_total_seq_len,\n                    prefill_chunk_size,\n                    page_size,\n                    support_sliding_window,\n                ],\n            ):\n                cache = kv_cache.FlashInferPagedKVCache(target=self.target, **kwargs)\n                bb.emit_func_output(cache._expr)  # pylint: disable=protected-access\n        except Exception as e:  # pylint: disable=broad-exception-caught\n            logger.info(\n                \"Error caught when creating FlashInfer PagedKVCache: %s\\n\"\n                \"The model will fall back to the TIR-based KV cache.\",\n                e,\n            )\n            return []\n\n        return cache.extern_mods\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/dispatch_triton_kernel.py",
    "content": "\"\"\"A pass that dispatches generic Triton kernel calls to specific kernel implementations.\"\"\"\n\n# pylint: disable=invalid-name\n\nfrom typing import List\n\nimport tvm\nfrom tvm import IRModule, relax\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\n\nfrom mlc_llm.op.triton import (\n    get_tir_w8a8_block_fp8_group_matmul,\n    get_tir_w8a8_block_fp8_matmul,\n)\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\n@mutator\nclass _Rewriter(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(self, mod: IRModule, target: tvm.target.Target) -> None:\n        super().__init__(mod)\n        self.mod = mod\n        self.target = target\n        self.extern_mods: List[tvm.runtime.Module] = []\n\n    def transform(self) -> tvm.IRModule:  # pylint: disable=too-many-locals\n        \"\"\"Entry point of the transformation\"\"\"\n        for g_var, func in self.mod.functions_items():\n            if not isinstance(func, relax.Function):\n                continue\n            new_func = self.visit_expr(func)\n            # new_func = remove_all_unused(new_func)\n            self.builder_.update_func(g_var, new_func)\n\n        mod = self.builder_.finalize()\n        mod_attrs = dict(mod.attrs) if mod.attrs else {}\n        mod = mod.with_attr(\n            \"external_mods\", list(mod_attrs.get(\"external_mods\", [])) + self.extern_mods\n        )\n        return mod\n\n    def visit_call_(self, call: relax.Call) -> relax.Expr:  # pylint: disable=arguments-renamed\n        call = super().visit_call_(call)\n\n        if (\n            call.op != tvm.ir.Op.get(\"relax.call_dps_packed\")\n            or not isinstance(call.args[0], relax.ExternFunc)\n            or not str(call.args[0].global_symbol).startswith(\"mlc.triton.\")\n        ):\n            return call\n\n        global_symbol = str(call.args[0].global_symbol)\n        assert isinstance(call.args[1], relax.Tuple)\n        if global_symbol 
== \"mlc.triton.w8a8_block_fp8_matmul\":\n            return self.w8a8_block_fp8_matmul(call.args[1].fields, call.struct_info)\n        if global_symbol == \"mlc.triton.w8a8_block_fp8_group_matmul\":\n            return self.w8a8_block_fp8_group_matmul(call.args[1].fields, call.struct_info)\n        raise ValueError(f\"Unknown mlc.triton kernel identifier: {global_symbol}\")\n\n    def w8a8_block_fp8_matmul(  # pylint: disable=too-many-locals\n        self, args: List[relax.Expr], out_sinfo: relax.StructInfo\n    ) -> relax.Expr:\n        \"\"\"Emit the w8a8_block_fp8_matmul triton kernel.\"\"\"\n        assert len(args) == 16\n        x, weight, x_scale, weight_scale = args[:4]\n        (\n            N,\n            K,\n            block_n,\n            block_k,\n            BLOCK_SIZE_M,\n            BLOCK_SIZE_N,\n            BLOCK_SIZE_K,\n            GROUP_SIZE_M,\n            num_warps,\n            num_stages,\n        ) = [arg.value.value for arg in args[4:14]]\n        in_dtype, out_dtype = str(args[14].value), str(args[15].value)\n\n        prim_func, func_name = get_tir_w8a8_block_fp8_matmul(\n            N,\n            K,\n            block_n,\n            block_k,\n            in_dtype,  # type: ignore\n            out_dtype,  # type: ignore\n            BLOCK_SIZE_M,\n            BLOCK_SIZE_N,\n            BLOCK_SIZE_K,\n            GROUP_SIZE_M,\n            num_warps,\n            num_stages,\n            self.extern_mods,\n        )\n        if prim_func is None:\n            # The TIR function is already in the IRModule\n            gv = self.builder_.get().get_global_var(func_name)\n        else:\n            # Add the TIR function to the IRModule\n            gv = self.builder_.add_func(prim_func, func_name)\n\n        return relax.call_tir(gv, [x, weight, x_scale, weight_scale], out_sinfo=out_sinfo)\n\n    def w8a8_block_fp8_group_matmul(  # pylint: disable=too-many-locals\n        self, args: List[relax.Expr], out_sinfo: relax.StructInfo\n   
 ) -> relax.Expr:\n        \"\"\"Emit the w8a8_block_fp8_group_matmul triton kernel.\"\"\"\n        assert len(args) == 19\n        x, weight, x_scale, weight_scale, expert_ids, indptr = args[:6]\n        (\n            N,\n            K,\n            num_experts,\n            block_n,\n            block_k,\n            BLOCK_SIZE_M,\n            BLOCK_SIZE_N,\n            BLOCK_SIZE_K,\n            GROUP_SIZE_M,\n            num_warps,\n            num_stages,\n        ) = [arg.value.value for arg in args[6:17]]\n        in_dtype, out_dtype = str(args[17].value), str(args[18].value)\n\n        prim_func, func_name = get_tir_w8a8_block_fp8_group_matmul(\n            N,\n            K,\n            num_experts,\n            block_n,\n            block_k,\n            in_dtype,  # type: ignore\n            out_dtype,  # type: ignore\n            BLOCK_SIZE_M,\n            BLOCK_SIZE_N,\n            BLOCK_SIZE_K,\n            GROUP_SIZE_M,\n            num_warps,\n            num_stages,\n            self.extern_mods,\n        )\n        if prim_func is None:\n            # The TIR function is already in the IRModule\n            gv = self.builder_.get().get_global_var(func_name)\n        else:\n            # Add the TIR function to the IRModule\n            gv = self.builder_.add_func(prim_func, func_name)\n\n        return relax.call_tir(\n            gv,\n            [x, weight, x_scale, weight_scale, expert_ids, indptr],\n            out_sinfo=out_sinfo,\n        )\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"DispatchTritonKernel\")\nclass DispatchTritonKernel:  # pylint: disable=too-many-instance-attributes,too-few-public-methods\n    \"\"\"Dispatch generic Triton kernel calls to specific kernel implementations.\"\"\"\n\n    def __init__(self, target: tvm.target.Target) -> None:\n        \"\"\"Initializer.\n\n        Parameters\n        ----------\n        target : tvm.target.Target\n            The target of the model compilation.\n        \"\"\"\n        self.target = target\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> 
IRModule:\n        \"\"\"Entrypoint\"\"\"\n        if self.target.kind.name != \"cuda\":\n            return mod\n\n        return _Rewriter(mod, self.target).transform()\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/estimate_memory_usage.py",
    "content": "\"\"\"Memory usage estimation analysis function for Relax functions.\"\"\"\n\nimport json\nfrom typing import Any, Dict\n\nimport tvm\nfrom tvm import relax, tir\nfrom tvm.ir import IRModule, Op\nfrom tvm.relax.expr_functor import PyExprVisitor, visitor\n\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"AttachMetadata\")\nclass AttachMetadataWithMemoryUsage:  # pylint: disable=too-few-public-methods\n    \"\"\"Attach a Relax function that returns metadata in a JSON string\"\"\"\n\n    def __init__(self, metadata: Dict[str, Any]):\n        self.metadata = metadata\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"Entrypoint\"\"\"\n\n        func_name = \"_metadata\"\n\n        def _emit_metadata(metadata):\n            bb = relax.BlockBuilder()  # pylint: disable=invalid-name\n            with bb.function(func_name, params=[]):\n                bb.emit_func_output(relax.StringImm(json.dumps(metadata)))\n            return bb.finalize()[func_name]\n\n        self.metadata[\"memory_usage\"] = _MemoryEstimator().run(mod)\n        mod[func_name] = _emit_metadata(self.metadata)\n        return mod\n\n\n@visitor\nclass _MemoryEstimator(PyExprVisitor):\n    \"\"\"The IR visitor which estimates the memory usage of each Relax function.\"\"\"\n\n    def __init__(self) -> None:\n        self.planned_alloc_mem = 0\n        self.planned_mem_num = 0\n        self._op_alloc_tensor = Op.get(\"relax.builtin.alloc_tensor\")\n        self._op_alloc_storage = Op.get(\"relax.memory.alloc_storage\")\n\n    def run(self, mod: IRModule) -> Dict[str, int]:\n        \"\"\"Entry point of the visitor.\"\"\"\n        result: Dict[str, int] = {}\n        for global_var, func in mod.functions_items():\n            if isinstance(func, relax.Function):\n                self.planned_alloc_mem = 0\n                self.planned_mem_num = 0\n        
        self.visit_expr(func)\n                result[global_var.name_hint] = self.planned_alloc_mem\n                logger.info(\n                    \"[Memory usage] Function `%s`: %.2f MB\",\n                    global_var.name_hint,\n                    self.planned_alloc_mem / 1024 / 1024,\n                )\n        return result\n\n    def visit_call_(self, call: relax.Call) -> None:  # pylint: disable=arguments-renamed\n        if call.op == self._op_alloc_tensor:\n            self._builtin_tensor_alloc(shape=call.args[0], dtype_str=call.args[1].value)\n        elif call.op == self._op_alloc_storage:\n            self._storage_alloc(size=call.args[0])\n        super().visit_call_(call)\n\n    def _builtin_tensor_alloc(self, shape: relax.Expr, dtype_str: str) -> None:\n        assert isinstance(shape, relax.ShapeExpr)\n        size = 1\n        for dim_len in shape.values:\n            if not isinstance(dim_len, tvm.tir.IntImm):\n                return\n            size *= dim_len.value\n        dtype = tvm.DataType(dtype_str)\n        self.planned_mem_num += 1\n        self.planned_alloc_mem += size * ((dtype.bits + 7) // 8) * dtype.lanes\n\n    def _storage_alloc(self, size: relax.Expr) -> None:\n        assert isinstance(size, relax.ShapeExpr)\n        if isinstance(size.values[0], tir.IntImm):\n            self.planned_mem_num += 1\n            self.planned_alloc_mem += size.values[0].value\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/fuse_add_norm.py",
    "content": "\"\"\"A compiler pass that fuses add + rms_norm.\"\"\"\n\n# pylint: disable=invalid-name\n\nfrom typing import Optional\n\nimport tvm\nfrom tvm import relax\nfrom tvm.relax.analysis import remove_all_unused\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\nfrom tvm.script import tir as T\n\nfrom ..support.max_thread_check import get_max_num_threads_per_block\n\n\ndef _get_add_rms_norm_decode(hidden_size: int, eps: float, TX: int, in_dtype: str):\n    if in_dtype not in (\"float16\", \"bfloat16\"):\n        raise ValueError(f\"Unsupported data type: {in_dtype}\")\n    inv_hidden_size = T.float32(1.0 / float(hidden_size))\n    eps = T.float32(eps)\n    add_local_size = hidden_size // TX\n\n    @T.prim_func(private=True)\n    def decode_add_rms(  # pylint: disable=too-many-locals\n        pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle\n    ):\n        T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n        batch_size = T.int32()\n        A = T.match_buffer(pA, (batch_size, 1, hidden_size), in_dtype)\n        B = T.match_buffer(pB, (batch_size, 1, hidden_size), in_dtype)\n        C = T.match_buffer(pC, (hidden_size,), in_dtype)\n        O = T.match_buffer(pO, (batch_size, 1, hidden_size), in_dtype)\n        add = T.match_buffer(pAdd, (batch_size, 1, hidden_size), in_dtype)\n        add_local = T.sblock_alloc_buffer((hidden_size // TX,), in_dtype, scope=\"local\")\n        sum_shared = T.sblock_alloc_buffer((batch_size, 1), scope=\"shared\")\n        sum_local = T.sblock_alloc_buffer((TX, batch_size, 1), scope=\"local\")\n        for v_bx in T.thread_binding(batch_size, thread=\"blockIdx.x\"):\n            for v_tx in T.thread_binding(\n                TX,\n                thread=\"threadIdx.x\",\n                annotations={\n                    \"pragma_auto_unroll_max_step\": 256,\n                    \"pragma_unroll_explicit\": 1,\n                },\n            ):\n                for i in 
range(add_local_size):\n                    with T.sblock(\"T_add\"):\n                        bx = T.axis.spatial(batch_size, v_bx)\n                        h = T.axis.spatial(hidden_size, i * TX + v_tx)\n                        add_local[h // TX] = A[bx, 0, h] + B[bx, 0, h]\n                    with T.sblock(\"T_write_back\"):\n                        bx = T.axis.spatial(batch_size, v_bx)\n                        v_ax1 = T.axis.spatial(1, 0)\n                        h = T.axis.spatial(hidden_size, i * TX + v_tx)\n                        add[bx, v_ax1, h] = add_local[h // TX]\n                with T.sblock(\"T_multiply_red_rf_init\"):\n                    tx, bx = T.axis.remap(\"SS\", [v_tx, v_bx])\n                    sum_local[tx, bx, 0] = T.float32(0)\n                for v_i, _j in T.grid(add_local_size, 1):\n                    with T.sblock(\"T_multiply_red_rf_update\"):\n                        tx, bx, i = T.axis.remap(\"SSR\", [v_tx, v_bx, v_i])\n                        sum_local[tx, bx, 0] += T.float32(add_local[i]) * T.float32(add_local[i])\n            for _j in range(1):\n                for v_tx_2 in T.thread_binding(TX, thread=\"threadIdx.x\"):\n                    with T.sblock(\"T_multiply_red\"):\n                        tx, bx = T.axis.remap(\"RS\", [v_tx_2, v_bx])\n                        T.reads(sum_local[tx, bx, 0])\n                        T.writes(sum_shared[bx, 0])\n                        with T.init():\n                            sum_shared[bx, 0] = T.float32(0)\n                        sum_shared[bx, 0] += sum_local[tx, bx, 0]\n            for i in range(add_local_size):\n                for v_tx_2 in T.thread_binding(TX, thread=\"threadIdx.x\"):\n                    with T.sblock(\"T_cast_2\"):\n                        bx = T.axis.spatial(batch_size, v_bx)\n                        h = T.axis.spatial(hidden_size, i * TX + v_tx_2)\n                        O[bx, 0, h] = T.cast(\n                            T.rsqrt(sum_shared[bx, 0] * 
inv_hidden_size + eps)\n                            * T.float32(add_local[h // TX])\n                            * T.float32(C[h]),\n                            dtype=in_dtype,\n                        )\n\n    return decode_add_rms\n\n\ndef _get_add_rms_norm_prefill(hidden_size: int, eps: float, TX: int, in_dtype: str):\n    if in_dtype not in (\"float16\", \"bfloat16\"):\n        raise ValueError(f\"Unsupported data type: {in_dtype}\")\n    inv_hidden_size = T.float32(1.0 / float(hidden_size))\n    eps = T.float32(eps)\n    add_local_size = hidden_size // TX\n\n    @T.prim_func(private=True)\n    def prefill_add_rms(  # pylint: disable=too-many-locals\n        pA: T.handle, pB: T.handle, pC: T.handle, pO: T.handle, pAdd: T.handle\n    ):\n        T.func_attr({\"tir.noalias\": T.bool(True), \"tir.is_scheduled\": 1})\n        seq_len = T.int32()\n        A = T.match_buffer(pA, (1, seq_len, hidden_size), in_dtype)\n        B = T.match_buffer(pB, (1, seq_len, hidden_size), in_dtype)\n        C = T.match_buffer(pC, (hidden_size,), in_dtype)\n        O = T.match_buffer(pO, (1, seq_len, hidden_size), in_dtype)\n        add = T.match_buffer(pAdd, (1, seq_len, hidden_size), in_dtype)\n        add_local = T.sblock_alloc_buffer((hidden_size // TX,), in_dtype, scope=\"local\")\n        sum_shared = T.sblock_alloc_buffer((1, seq_len), scope=\"shared\")\n        sum_local = T.sblock_alloc_buffer((TX, 1, seq_len), scope=\"local\")\n        for v_bx in T.thread_binding(seq_len, thread=\"blockIdx.x\"):\n            for v_tx in T.thread_binding(\n                TX,\n                thread=\"threadIdx.x\",\n                annotations={\n                    \"pragma_auto_unroll_max_step\": 256,\n                    \"pragma_unroll_explicit\": 1,\n                },\n            ):\n                for v_i in range(add_local_size):\n                    with T.sblock(\"T_add\"):\n                        bx = T.axis.spatial(seq_len, v_bx)\n                        h = 
T.axis.spatial(hidden_size, v_i * TX + v_tx)\n                        add_local[h // TX] = A[0, bx, h] + B[0, bx, h]\n                    with T.sblock(\"T_write_back\"):\n                        bx = T.axis.spatial(seq_len, v_bx)\n                        h = T.axis.spatial(hidden_size, v_i * TX + v_tx)\n                        add[0, bx, h] = add_local[h // TX]\n                with T.sblock(\"T_multiply_red_rf_init\"):\n                    tx, bx = T.axis.remap(\"SS\", [v_tx, v_bx])\n                    sum_local[tx, 0, bx] = T.float32(0)\n                for v_i, _j in T.grid(add_local_size, 1):\n                    with T.sblock(\"T_multiply_red_rf_update\"):\n                        tx, bx, i = T.axis.remap(\"SSR\", [v_tx, v_bx, v_i])\n                        sum_local[tx, 0, bx] += T.float32(add_local[i]) * T.float32(add_local[i])\n            for _j in range(1):\n                for v_tx_2 in T.thread_binding(TX, thread=\"threadIdx.x\"):\n                    with T.sblock(\"T_multiply_red\"):\n                        tx, bx = T.axis.remap(\"RS\", [v_tx_2, v_bx])\n                        with T.init():\n                            sum_shared[0, bx] = T.float32(0)\n                        sum_shared[0, bx] = sum_shared[0, bx] + sum_local[tx, 0, bx]\n            for v_i in range(add_local_size):\n                for v_tx_2 in T.thread_binding(TX, thread=\"threadIdx.x\"):\n                    with T.sblock(\"T_cast_2\"):\n                        bx = T.axis.spatial(seq_len, v_bx)\n                        v1 = T.axis.spatial(hidden_size, v_i * TX + v_tx_2)\n                        O[0, bx, v1] = T.cast(\n                            T.rsqrt(sum_shared[0, bx] * inv_hidden_size + eps)\n                            * T.float32(add_local[v1 // TX])\n                            * T.float32(C[v1]),\n                            dtype=in_dtype,\n                        )\n\n    return prefill_add_rms\n\n\n@tvm.transform.module_pass(opt_level=0, 
name=\"FuseAddRMSNorm\")\nclass FuseAddRMSNorm:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that fuses add + rms_norm.\"\"\"\n\n    def __init__(self, target: tvm.target.Target) -> None:\n        \"\"\"Initializer.\n\n        Parameters\n        ----------\n        target : tvm.target.Target\n            Target device.\n        \"\"\"\n        self.target = target\n\n    def transform_module(self, mod: tvm.IRModule, _ctx: tvm.transform.PassContext) -> tvm.IRModule:\n        \"\"\"IRModule-level transformation.\"\"\"\n        return _FuseAddRMSNormRewriter(mod.clone(), self.target).transform()\n\n\n@mutator\nclass _FuseAddRMSNormRewriter(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(self, mod: tvm.IRModule, target: tvm.target.Target):\n        super().__init__(mod)\n        self.mod = mod\n        self.prefill_norm_gv: Optional[tvm.ir.GlobalVar] = None\n        self.decode_norm_gv: Optional[tvm.ir.GlobalVar] = None\n        self.TX = min(1024, get_max_num_threads_per_block(target))\n\n    def transform(self) -> tvm.IRModule:  # pylint: disable=too-many-locals\n        \"\"\"Entry point of the transformation\"\"\"\n        for g_var, func in self.mod.functions_items():\n            if not isinstance(func, relax.Function):\n                continue\n            new_func = self.visit_expr(func)\n            new_func = remove_all_unused(new_func)\n            self.builder_.update_func(g_var, new_func)\n        return self.builder_.finalize()\n\n    def visit_call_(self, call: relax.Call) -> relax.Expr:  # pylint: disable=arguments-renamed\n        call = super().visit_call_(call)\n\n        # Match the \"rms_norm(add(x1, x2), w)\" pattern\n        if call.op != tvm.ir.Op.get(\"relax.nn.rms_norm\") or call.struct_info.dtype not in [\n            \"bfloat16\",\n            \"float16\",\n        ]:\n            return call\n        assert len(call.args) == 2\n        weight = call.args[1]\n        eps = call.attrs.epsilon\n   
     assert isinstance(call.args[0], relax.Var)\n        y = self.lookup_binding(call.args[0])\n        if not isinstance(y, relax.Call) or y.op != tvm.ir.Op.get(\"relax.add\"):\n            return call\n        assert len(y.args) == 2\n        x1 = y.args[0]\n        x2 = y.args[1]\n        # The fused kernels split the hidden dim across TX threads; require h % TX == 0\n        n, _, h = x1.struct_info.shape\n        h = int(h)\n        if h % self.TX != 0:\n            return call\n\n        # Prefill inputs are (1, seq_len, h); decode inputs are (batch_size, 1, h)\n        is_prefill = n == 1\n        func_gv = self.prefill_norm_gv if is_prefill else self.decode_norm_gv\n        if func_gv is None:\n            if is_prefill:\n                func_gv = self.builder_.add_func(\n                    _get_add_rms_norm_prefill(h, eps, self.TX, call.struct_info.dtype),\n                    \"fuse_add_norm_prefill\",\n                )\n                self.prefill_norm_gv = func_gv\n            else:\n                func_gv = self.builder_.add_func(\n                    _get_add_rms_norm_decode(h, eps, self.TX, call.struct_info.dtype),\n                    \"fuse_add_norm_decode\",\n                )\n                self.decode_norm_gv = func_gv\n\n        tuple_output = self.builder_.emit(\n            relax.call_tir(\n                func_gv,\n                [x1, x2, weight],\n                out_sinfo=[x1.struct_info, x2.struct_info],\n            )\n        )\n        new_o = relax.TupleGetItem(tuple_output, 0)\n        new_y = self.builder_.emit(relax.TupleGetItem(tuple_output, 1))\n        self.set_var_remap(call.args[0].vid, new_y)\n        return new_o\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/fuse_dequantize_matmul_ewise.py",
    "content": "\"\"\"A compiler pass that fuses dequantize + matmul + elementwise.\"\"\"\n\nimport tvm\nfrom tvm import IRModule, relax\nfrom tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseDequantizeMatmulEwise\")\nclass FuseDequantizeMatmulEwise:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that fuses dequantize + matmul + elementwise.\"\"\"\n\n    def transform_module(\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        seq = []\n        for n_aux_tensor in [0, 1, 2, 3, 4]:\n            for match_ewise in [0, 1, 2, 3, 6]:\n                if match_ewise == 6 and n_aux_tensor != 4:\n                    continue\n                seq.append(\n                    relax.transform.FuseOpsByPattern(\n                        [\n                            (\n                                \"dequantize_matmul\",\n                                *_pattern(match_ewise, n_aux_tensor),\n                            )\n                        ]\n                    )\n                )\n        seq.append(relax.transform.FuseTIR())\n        return tvm.transform.Sequential(seq)(mod)\n\n\ndef _pattern(match_ewise: int, n_aux_tensor: int):\n    # pylint: disable=invalid-name\n    w_scaled = wildcard()\n    x = wildcard()\n    w = is_op(\"relax.call_tir\")(\n        GlobalVarPattern(),\n        TuplePattern([w_scaled] + [wildcard() for _ in range(n_aux_tensor)]),\n        add_constraint=False,\n    )\n    matmul = is_op(\"relax.call_tir\")(\n        GlobalVarPattern(),\n        TuplePattern([x, w] + [wildcard() for _ in range(match_ewise)]),\n        add_constraint=False,\n    )\n    # pylint: enable=invalid-name\n    annotations = {\n        \"w_scaled\": w_scaled,\n        \"x\": x,\n        \"w\": w,\n        \"matmul\": matmul,\n    }\n\n    def 
_check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:\n        call = ctx.annotated_expr[\"w\"]\n        if not isinstance(call, relax.Call):\n            return False\n        g_var = call.args[0]\n        if not isinstance(g_var, relax.GlobalVar):\n            return False\n        return g_var.name_hint.startswith(\"dequantize\") or g_var.name_hint.startswith(\n            \"fused_dequantize\"\n        )\n\n    def _check_matmul(ctx: relax.transform.PatternCheckContext) -> bool:\n        call = ctx.annotated_expr[\"matmul\"]\n        if not isinstance(call, relax.Call):\n            return False\n        g_var = call.args[0]\n        if not isinstance(g_var, relax.GlobalVar):\n            return False\n        return (\n            g_var.name_hint.startswith(\"matmul\")\n            or g_var.name_hint.startswith(\"fused_matmul\")\n            or g_var.name_hint.startswith(\"NT_matmul\")\n            or g_var.name_hint.startswith(\"fused_NT_matmul\")\n        )\n\n    def _check(ctx: relax.transform.PatternCheckContext) -> bool:\n        return _check_decoding(ctx) and _check_matmul(ctx)\n\n    return matmul, annotations, _check\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/fuse_dequantize_take.py",
    "content": "\"\"\"A compiler pass that fuses dequantize + take.\"\"\"\n\nimport tvm\nfrom tvm import IRModule, relax, tir\nfrom tvm.relax.dpl.pattern import (\n    GlobalVarPattern,\n    TuplePattern,\n    is_const,\n    is_op,\n    wildcard,\n)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseDequantizeTake\")\nclass FuseDequantizeTake:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that fuses dequantize + take.\"\"\"\n\n    def transform_module(  # pylint: disable=too-many-locals\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        seq = []\n        for n_aux_tensor in [2, 3]:\n            for match_tir_vars in [False, True]:\n                seq.append(\n                    relax.transform.FuseOpsByPattern(\n                        [\n                            (\n                                \"dequantize_take\",\n                                *_pattern(n_aux_tensor, match_tir_vars),\n                            )\n                        ]\n                    )\n                )\n        seq.append(relax.transform.FuseTIR())\n        mod = tvm.transform.Sequential(seq)(mod)\n        for g_var, func in mod.functions_items():\n            name = g_var.name_hint\n            if isinstance(func, tir.PrimFunc) and (\n                (\"fused_dequantize\" in name) and (\"take\" in name)\n            ):\n                sch_mod = tvm.IRModule({\"main\": func})\n                sch_mod = tir.transform.ForceNarrowIndexToInt32()(sch_mod)\n                sch = tvm.s_tir.Schedule(sch_mod)\n                sch.compute_inline(\"dequantize\")\n                mod[g_var] = sch.mod[\"main\"]\n        return mod\n\n\ndef _pattern(n_aux_tensor: int, match_tir_vars: bool):\n    dequantize = is_op(\"relax.call_tir\")(\n        GlobalVarPattern(),\n        TuplePattern([wildcard() for _ in range(n_aux_tensor)]),\n        
add_constraint=False,\n    )\n    indices = ~is_const()\n    if match_tir_vars:\n        call_tir_args_take = [\n            GlobalVarPattern(),\n            TuplePattern([dequantize, indices]),\n            wildcard(),\n        ]\n    else:\n        call_tir_args_take = [\n            GlobalVarPattern(),\n            TuplePattern([dequantize, indices]),\n        ]\n    take = is_op(\"relax.call_tir\")(\n        *call_tir_args_take,\n        add_constraint=False,\n    )\n    annotations = {\n        \"take\": take,\n        \"dequantize\": dequantize,\n        \"indices\": indices,\n    }\n\n    def _check(ctx: relax.transform.PatternCheckContext) -> bool:\n        take = ctx.annotated_expr[\"take\"]\n        dequantize = ctx.annotated_expr[\"dequantize\"]\n        if not isinstance(dequantize, relax.expr.Call):\n            return False\n        if not isinstance(take.args[0], relax.GlobalVar) or not isinstance(\n            dequantize.args[0], relax.GlobalVar\n        ):\n            return False\n        return \"take\" in take.args[0].name_hint and \"dequantize\" in dequantize.args[0].name_hint\n\n    return take, annotations, _check\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/fuse_dequantize_transpose.py",
    "content": "\"\"\"A compiler pass that fuses transpose + dequantize.\"\"\"\n\nimport tvm\nfrom tvm import relax, s_tir, tir\nfrom tvm.ir.module import IRModule\nfrom tvm.relax.analysis import remove_all_unused\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseDequantizeTranspose\")\nclass FuseDequantizeTranspose:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that fuses transpose + dequantize.\"\"\"\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        return _DequantizeTransposeFuser(mod).transform()\n\n\n@mutator\nclass _DequantizeTransposeFuser(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(\n        self,\n        mod: IRModule,\n    ):\n        super().__init__(mod)\n        self.mod = mod\n\n    def transform(self) -> IRModule:\n        \"\"\"Entry point\"\"\"\n        for g_var, func in self.mod.functions_items():\n            if isinstance(func, relax.Function):\n                updated_func = self.visit_expr(func)\n                updated_func = remove_all_unused(updated_func)\n                self.builder_.update_func(g_var, updated_func)\n        return self.builder_.get()\n\n    def visit_call_(  # pylint: disable=arguments-renamed\n        self,\n        call: relax.Call,\n    ) -> relax.Expr:\n        call = self.visit_expr_post_order(call)\n        if call.op != tvm.ir.Op.get(\"relax.matmul\"):\n            return call\n        # Do not fuse dequantize-transpose for GeMM\n        if (\n            call.args[0].struct_info.ndim < 2\n            or not isinstance(call.args[0].struct_info.shape[-2], tir.IntImm)\n            or call.args[0].struct_info.shape[-2].value != 1\n        ):\n            return call\n\n        matmul_rhs = self.lookup_binding(call.args[1])\n        if (\n            not isinstance(matmul_rhs, relax.Call)\n            
or matmul_rhs.op != tvm.ir.Op.get(\"relax.permute_dims\")\n            or matmul_rhs.args[0].struct_info.ndim != 2\n            or matmul_rhs.attrs.axes is not None\n        ):\n            return call\n\n        transpose_input = self.lookup_binding(matmul_rhs.args[0])\n        if (\n            not isinstance(transpose_input, relax.Call)\n            or transpose_input.op != tvm.ir.Op.get(\"relax.call_tir\")\n            or not transpose_input.args[0].name_hint.startswith(\"dequantize\")\n            or not isinstance(transpose_input.struct_info, relax.TensorStructInfo)\n        ):\n            return call\n\n        dequantize_tir_func = self.mod[transpose_input.args[0]]\n        assert isinstance(dequantize_tir_func, tir.PrimFunc)\n        if (  # pylint: disable=too-many-boolean-expressions\n            len(dequantize_tir_func.body.block.alloc_buffers) != 1\n            or not isinstance(dequantize_tir_func.body.block.body, tir.SeqStmt)\n            or len(dequantize_tir_func.body.block.body) != 2\n            or not isinstance(dequantize_tir_func.body.block.body[1], tir.For)\n            or not isinstance(dequantize_tir_func.body.block.body[1].body.body, tir.SBlockRealize)\n            or dequantize_tir_func.body.block.body[1].body.body.block.name_hint != \"T_transpose\"\n        ):\n            return call\n\n        new_func_buffers = [\n            dequantize_tir_func.buffer_map[var] for var in dequantize_tir_func.params\n        ]\n        new_func_buffers[-1] = dequantize_tir_func.body.block.alloc_buffers[0]\n        new_func = tir.PrimFunc(\n            params=new_func_buffers,\n            body=tir.SBlockRealize(\n                iter_values=[],\n                predicate=True,\n                block=tir.SBlock(\n                    iter_vars=[],\n                    reads=[],\n                    writes=[],\n                    name_hint=\"root\",\n                    body=dequantize_tir_func.body.block.body[0],\n                ),\n            ),\n   
     )\n        # Call `renew_defs` for deep-copy to avoid IR node duplication in\n        # different PrimFuncs of an IRModule.\n        new_func = s_tir.renew_defs(new_func)\n        g_var = self.builder_.add_func(new_func, func_name=\"dequantize\")\n        dequantize_matmul_rhs = self.builder_.emit(\n            relax.call_tir(g_var, transpose_input.args[1], out_sinfo=matmul_rhs.struct_info)\n        )\n        return relax.op.matmul(call.args[0], dequantize_matmul_rhs, out_dtype=call.attrs.out_dtype)\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/fuse_ft_dequantize_matmul_epilogue.py",
    "content": "\"\"\"A compiler pass that fuses dequantize matmul + epilogue.\"\"\"\n\nimport operator\nfrom functools import reduce\n\nimport tvm\nfrom tvm import IRModule, relax\nfrom tvm.relax.dpl import rewrite_call\nfrom tvm.relax.dpl.pattern import is_op, wildcard\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseDequantizeEpilogue\")\nclass FuseFTDequantizeEpilogue:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that fuses FasterTransformer dequantize matmul + epilogue.\"\"\"\n\n    def transform_module(\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        for gv, func in mod.functions_items():\n            if isinstance(func, relax.Function):\n                func = fuse_bias(func)\n                func = fuse_activation(func)\n                func = fuse_residual_binary(func)\n                func = fuse_residual_unary(func)\n                mod[gv] = func\n        return mod\n\n\ndef fuse_bias(func: relax.Function) -> relax.Function:\n    \"\"\"\n    Fuse following `relax.add` into fastertransformer.gemm_fp16_int as bias:\n\n    Before:\n    ```\n    lv1 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int\", ...)\n    lv2 = relax.add(lv1, bias)\n\n    ```\n    After:\n    ```\n    lv2 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int_bias\", ..., bias, ...)\n    ```\n\n    Parameters\n    ----------\n    func : relax.Function\n        The function before fusion.\n\n    Returns\n    -------\n    ret : relax.Function\n        The function after fusion.\n    \"\"\"\n    decode_matmul = is_op(\"relax.call_dps_packed\")(varg_default_wildcard=True)\n    bias = wildcard()\n    pattern = is_op(\"relax.add\")(decode_matmul, bias) | is_op(\"relax.add\")(bias, decode_matmul)\n\n    def rewriter(expr, match):\n        if match[decode_matmul].args[0].global_symbol == \"fastertransformer.gemm_fp16_int\":\n            
assert len(match[decode_matmul].args) == 2\n            args_list = match[decode_matmul].args[1]\n            assert len(args_list) == 8\n            if not args_list[3].value == \"identity\":\n                # bias cannot be fused after activation\n                return expr\n            matched_bias = match[bias]\n            bias_stride = (\n                matched_bias.struct_info.shape[-1]\n                if bias\n                and not reduce(operator.mul, matched_bias.struct_info.shape, 1)\n                == matched_bias.struct_info.shape[-1]\n                else 0\n            )\n            return relax.call_dps_packed(\n                \"fastertransformer.gemm_fp16_int_bias\",\n                [\n                    args_list[0],  # x\n                    args_list[1],  # weight\n                    args_list[2],  # scale\n                    matched_bias,  # bias\n                    args_list[3],  # activation\n                    args_list[4],  # m\n                    args_list[5],  # n\n                    args_list[6],  # k\n                    args_list[7],  # group_size\n                    bias_stride,  # bias_stride\n                ],\n                out_sinfo=match[decode_matmul].struct_info,\n            )\n        return expr\n\n    return rewrite_call(pattern, rewriter, func)\n\n\ndef fuse_activation(func: relax.Function) -> relax.Function:\n    \"\"\"\n    Fuse following `relax.nn.silu/relu/gelu` into fastertransformer.gemm_fp16_int_bias\n    as activation:\n\n    Before:\n    ```\n    lv1 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int_bias\", ...)\n    lv2 = relax.silu(lv1)\n\n    ```\n    After:\n    ```\n    lv2 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int_bias\", ..., \"silu\", ...)\n    ```\n\n    Parameters\n    ----------\n    func : relax.Function\n        The function before fusion.\n\n    Returns\n    -------\n    ret : relax.Function\n        The function after fusion.\n    \"\"\"\n    # pylint: 
disable=unsupported-binary-operation\n    decode_matmul = is_op(\"relax.call_dps_packed\")(varg_default_wildcard=True)\n    pattern = (\n        is_op(\"relax.nn.silu\")(decode_matmul)\n        | is_op(\"relax.nn.gelu\")(decode_matmul)\n        | is_op(\"relax.nn.relu\")(decode_matmul)\n    )\n\n    def rewriter(expr, match):\n        if match[decode_matmul].args[0].global_symbol == \"fastertransformer.gemm_fp16_int\":\n            matched_activation = match[pattern]\n            assert matched_activation.op.name in [\n                \"relax.nn.silu\",\n                \"relax.nn.gelu\",\n                \"relax.nn.relu\",\n            ]\n            assert len(match[decode_matmul].args) == 2\n            args_list = match[decode_matmul].args[1]\n            assert len(args_list) == 8\n            return relax.call_dps_packed(\n                \"fastertransformer.gemm_fp16_int\",\n                [\n                    args_list[0],  # x\n                    args_list[1],  # weight\n                    args_list[2],  # scale\n                    matched_activation.op.name[9:],  # activation\n                    args_list[4],  # m\n                    args_list[5],  # n\n                    args_list[6],  # k\n                    args_list[7],  # group_size\n                ],\n                out_sinfo=match[decode_matmul].struct_info,\n            )\n        if match[decode_matmul].args[0].global_symbol == \"fastertransformer.gemm_fp16_int_bias\":\n            matched_activation = match[pattern]\n            assert matched_activation.op.name in [\n                \"relax.nn.silu\",\n                \"relax.nn.gelu\",\n                \"relax.nn.relu\",\n            ]\n            assert len(match[decode_matmul].args) == 2\n            args_list = match[decode_matmul].args[1]\n            assert len(args_list) == 10\n            return relax.call_dps_packed(\n                \"fastertransformer.gemm_fp16_int_bias\",\n                [\n                    
args_list[0],  # x\n                    args_list[1],  # weight\n                    args_list[2],  # scale\n                    args_list[3],  # bias\n                    matched_activation.op.name[9:],  # activation\n                    args_list[5],  # m\n                    args_list[6],  # n\n                    args_list[7],  # k\n                    args_list[8],  # group_size\n                    args_list[9],  # bias_stride\n                ],\n                out_sinfo=match[decode_matmul].struct_info,\n            )\n        return expr\n\n    return rewrite_call(pattern, rewriter, func)\n\n\ndef fuse_residual_binary(func: relax.Function) -> relax.Function:\n    \"\"\"\n    Fuse following `relax.add/multiply` into fastertransformer.gemm_fp16_int_bias as\n    residual binary operation:\n\n    Before:\n    ```\n    lv1 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int_bias\", ...)\n    lv2 = relax.add(lv1, residual)\n\n    ```\n    After:\n    ```\n    lv2 = relax.call_dps_packed(\n        \"fastertransformer.gemm_fp16_int_bias_residual\",\n        ...,\n        residual,\n        ...,\n        \"plus\",\n        ...\n    )\n    ```\n\n    Parameters\n    ----------\n    func : relax.Function\n        The function before fusion.\n\n    Returns\n    -------\n    ret : relax.Function\n        The function after fusion.\n    \"\"\"\n    # pylint: disable=unsupported-binary-operation\n    decode_matmul = is_op(\"relax.call_dps_packed\")(varg_default_wildcard=True)\n    residual = wildcard()\n    pattern = (\n        is_op(\"relax.add\")(decode_matmul, residual)\n        | is_op(\"relax.add\")(residual, decode_matmul)\n        | is_op(\"relax.multiply\")(decode_matmul, residual)\n        | is_op(\"relax.multiply\")(residual, decode_matmul)\n    )\n\n    def rewriter(expr, match):\n        if match[decode_matmul].args[0].global_symbol == \"fastertransformer.gemm_fp16_int_bias\":\n            matched_binary = match[pattern]\n            assert 
matched_binary.op.name in [\"relax.add\", \"relax.multiply\"]\n            binary_op = \"plus\" if matched_binary.op.name == \"relax.add\" else \"multiply\"\n            assert len(match[decode_matmul].args) == 2\n            args_list = match[decode_matmul].args[1]\n            assert len(args_list) == 10\n            matched_residual = match[residual]\n            if not args_list[9].value == 0:\n                # fastertransformer.gemm_fp16_int_bias_residual does not support\n                # bias_stride != 0 yet\n                return expr\n            return relax.call_dps_packed(\n                \"fastertransformer.gemm_fp16_int_bias_residual\",\n                [\n                    args_list[0],  # x\n                    args_list[1],  # weight\n                    args_list[2],  # scale\n                    args_list[3],  # bias\n                    matched_residual,  # residual\n                    args_list[4],  # activation\n                    binary_op,  # binary_op\n                    \"identity\",  # unary_op\n                    args_list[5],  # m\n                    args_list[6],  # n\n                    args_list[7],  # k\n                    args_list[8],  # group_size\n                ],\n                out_sinfo=match[decode_matmul].struct_info,\n            )\n        return expr\n\n    return rewrite_call(pattern, rewriter, func)\n\n\ndef fuse_residual_unary(func: relax.Function) -> relax.Function:\n    \"\"\"\n    Fuse following `relax.nn.silu/relu/gelu` into fastertransformer.gemm_fp16_int_bias_residual\n    as residual unary operation:\n\n    Before:\n    ```\n    lv1 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int_bias_residual\", ...)\n    lv2 = relax.silu(lv1)\n\n    ```\n    After:\n    ```\n    lv2 = relax.call_dps_packed(\"fastertransformer.gemm_fp16_int_bias_residual\", ..., \"silu\", ...)\n    ```\n\n    Parameters\n    ----------\n    func : relax.Function\n        The function before fusion.\n\n    Returns\n    
-------\n    ret : relax.Function\n        The function after fusion.\n    \"\"\"\n    # pylint: disable=unsupported-binary-operation\n    decode_matmul = is_op(\"relax.call_dps_packed\")(varg_default_wildcard=True)\n    pattern = (\n        is_op(\"relax.nn.silu\")(decode_matmul)\n        | is_op(\"relax.nn.gelu\")(decode_matmul)\n        | is_op(\"relax.nn.relu\")(decode_matmul)\n    )\n\n    def rewriter(expr, match):\n        if (\n            match[decode_matmul].args[0].global_symbol\n            == \"fastertransformer.gemm_fp16_int_bias_residual\"\n        ):\n            matched_activation = match[pattern]\n            assert matched_activation.op.name in [\n                \"relax.nn.silu\",\n                \"relax.nn.gelu\",\n                \"relax.nn.relu\",\n            ]\n            assert len(match[decode_matmul].args) == 2\n            args_list = match[decode_matmul].args[1]\n            assert len(args_list) == 12\n            return relax.call_dps_packed(\n                \"fastertransformer.gemm_fp16_int_bias_residual\",\n                [\n                    args_list[0],  # x\n                    args_list[1],  # weight\n                    args_list[2],  # scale\n                    args_list[3],  # bias\n                    args_list[4],  # residual\n                    args_list[5],  # activation\n                    args_list[6],  # binary_op\n                    matched_activation.op.name[9:],  # unary_op\n                    args_list[8],  # m\n                    args_list[9],  # n\n                    args_list[10],  # k\n                    args_list[11],  # group_size\n                ],\n                out_sinfo=match[decode_matmul].struct_info,\n            )\n        return expr\n\n    return rewrite_call(pattern, rewriter, func)\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/fuse_transpose_matmul.py",
    "content": "\"\"\"A compiler pass that fuses transpose + matmul.\"\"\"\n\nimport tvm\nfrom tvm import IRModule, relax, te, tir\nfrom tvm.relax.dpl.pattern import is_op, wildcard\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"FuseTransposeMatmul\")\nclass FuseTransposeMatmul:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that fuses transpose + matmul.\"\"\"\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        mod = relax.transform.FuseOpsByPattern(\n            [\n                (\n                    \"transpose_matmul_fuse\",\n                    *_pattern(),\n                ),\n            ]\n        )(mod)\n        transpose_matmul_codegen = _TransposeMatmulFuser(mod)\n        for g_var, func in mod.functions_items():\n            if isinstance(func, relax.Function):\n                func = transpose_matmul_codegen.visit_expr(func)\n                transpose_matmul_codegen.builder_.update_func(g_var, func)\n        return transpose_matmul_codegen.builder_.get()\n\n\ndef _pattern():\n    \"\"\"Pattern for transpose + matmul.\"\"\"\n    # pylint: disable=invalid-name\n    w = wildcard()\n    x = wildcard()\n    wT = is_op(\"relax.permute_dims\")(w)\n    o = is_op(\"relax.matmul\")(x, wT)\n    # pylint: enable=invalid-name\n    annotations = {\"o\": o, \"w\": w, \"x\": x, \"wT\": wT}\n\n    def _check(context: relax.transform.PatternCheckContext) -> bool:\n        transpose_call = context.annotated_expr[\"wT\"]\n        ndim = transpose_call.args[0].struct_info.ndim\n        if ndim == -1:\n            return False\n        if ndim == 2 and transpose_call.attrs.axes is None:\n            return True\n        axes = list(range(ndim))\n        axes[-1], axes[-2] = axes[-2], axes[-1]\n        return list(transpose_call.attrs.axes) == axes\n\n    return o, annotations, 
_check\n\n\n# pylint: disable=missing-docstring,invalid-name\n\n\n@mutator\nclass _TransposeMatmulFuser(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(self, mod):\n        super().__init__(mod)\n\n    def visit_call_(  # pylint: disable=arguments-renamed\n        self,\n        call: relax.Call,\n    ) -> relax.Expr:\n        out_dtype = None\n\n        def te_transposed_matmul(a: te.Tensor, b: te.Tensor) -> te.Tensor:\n            nonlocal out_dtype\n            a_shape = list(a.shape)\n            b_shape = list(b.shape)\n            a_prepended = False\n            b_appended = False\n            if len(a_shape) == 1:\n                a_prepended = True\n                a_shape.insert(0, 1)\n            if len(b_shape) == 1:\n                b_appended = True\n                b_shape.append(1)\n\n            is_a_larger = len(a_shape) > len(b_shape)\n            offset = len(a_shape) - len(b_shape) if is_a_larger else len(b_shape) - len(a_shape)\n\n            a_relax = relax.Var(\"a\", relax.TensorStructInfo(a.shape))\n            bT_shape = list(b.shape)\n            bT_shape[-1], bT_shape[-2] = bT_shape[-2], bT_shape[-1]\n            bT_relax = relax.Var(\"b\", relax.TensorStructInfo(bT_shape))\n            output_shape = self.builder_.normalize(\n                relax.op.matmul(a_relax, bT_relax)\n            ).struct_info.shape\n\n            def matmul_compute(*idx_spatial):\n                k = te.reduce_axis((0, a_shape[-1]), name=\"k\")\n\n                def multiply_compute(idx_reduce):\n                    a_indices = []\n                    b_indices = []\n\n                    for i in range(offset):\n                        if is_a_larger:\n                            a_indices.append(idx_spatial[i])\n                        else:\n                            b_indices.append(idx_spatial[i])\n                    for i in range(offset, len(output_shape) - (2 - a_prepended - b_appended)):\n                        a_dim = 
a_shape[i if is_a_larger else i - offset]\n                        b_dim = b_shape[i if not is_a_larger else i - offset]\n                        dim_equal = a_dim == b_dim\n                        if not isinstance(dim_equal, tir.IntImm) or dim_equal == 0:\n                            a_dim_is_one = isinstance(a_dim, tir.IntImm) and a_dim == 1\n                            b_dim_is_one = isinstance(b_dim, tir.IntImm) and b_dim == 1\n                            a_indices.append(0 if a_dim_is_one else idx_spatial[i])\n                            b_indices.append(0 if b_dim_is_one else idx_spatial[i])\n                        else:\n                            a_indices.append(idx_spatial[i])\n                            b_indices.append(idx_spatial[i])\n\n                    if not a_prepended:\n                        a_indices.append(idx_spatial[-2 + b_appended])\n                    a_indices.append(idx_reduce)\n                    if not b_appended:\n                        b_indices.append(idx_spatial[-1])\n                    b_indices.append(idx_reduce)\n\n                    dtype = out_dtype\n                    if dtype != \"\":\n                        return a(*a_indices).astype(dtype) * b(*b_indices).astype(dtype)\n                    return a(*a_indices) * b(*b_indices)\n\n                return te.sum(multiply_compute(k), axis=k)\n\n            return te.compute(\n                output_shape,\n                lambda *idx: matmul_compute(*idx),  # pylint: disable=unnecessary-lambda\n                name=\"NT_matmul\",\n            )\n\n        if isinstance(call.op, relax.GlobalVar):\n            function = self.builder_.get()[call.op]\n            if (\n                \"Composite\" in function.attrs\n                and function.attrs[\"Composite\"] == \"transpose_matmul_fuse\"\n            ):\n                out_dtype = function.ret_struct_info.dtype\n                return self.builder_.call_te(\n                    te_transposed_matmul,\n         
           call.args[1],\n                    call.args[0],\n                    primfunc_name_hint=\"NT_matmul\",\n                )\n\n        return super().visit_call_(call)\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/lift_global_buffer_alloc.py",
    "content": "\"\"\"A compiler pass that lifts TIR-level global allocation to Relax.\"\"\"\n\nfrom typing import Dict, List, Tuple\n\nimport tvm\nfrom tvm import relax, tir\nfrom tvm.ir.module import IRModule\nfrom tvm.relax.analysis import remove_all_unused\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"LiftTIRGlobalBufferAlloc\")\nclass LiftTIRGlobalBufferAlloc:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that lifts TIR-level global allocation to Relax.\"\"\"\n\n    def transform_module(\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        return _TIRGlobalAllocRewriter(mod).transform()\n\n\n@mutator\nclass _TIRGlobalAllocRewriter(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(self, mod: IRModule):\n        super().__init__(mod)\n        self.mod = mod\n        self.gv2new_tensor_sinfo: Dict[\n            tvm.ir.GlobalVar,\n            Tuple[tvm.ir.GlobalVar, List[relax.TensorStructInfo], tir.PrimFunc],\n        ] = {}\n\n    def transform(self) -> IRModule:\n        \"\"\"Entry point of the transformation\"\"\"\n        for g_var, func in self.mod.functions_items():\n            if isinstance(func, tir.PrimFunc):\n                updated_func, tensor_sinfo_list = remove_global_buf_alloc(func)\n                if len(tensor_sinfo_list) > 0:\n                    new_gv = self.builder_.add_func(updated_func, g_var.name_hint)\n                    self.gv2new_tensor_sinfo[g_var] = (new_gv, tensor_sinfo_list, func)\n\n        self.mod = self.builder_.get()\n        for g_var, func in self.mod.functions_items():\n            if isinstance(func, relax.Function):\n                updated_func = self.visit_expr(func)\n                updated_func = remove_all_unused(updated_func)\n                self.builder_.update_func(g_var, updated_func)\n\n        
mod = self.builder_.get()\n        return relax.transform.DeadCodeElimination()(mod)\n\n    def visit_call_(self, call: relax.Call):  # pylint: disable=arguments-renamed\n        call = self.visit_expr_post_order(call)\n        if (\n            call.op != tvm.ir.Op.get(\"relax.call_tir\")\n            or call.args[0] not in self.gv2new_tensor_sinfo\n        ):\n            return call\n\n        g_var = call.args[0]\n        new_gv, tensor_sinfo, func_before_update = self.gv2new_tensor_sinfo[g_var]\n\n        assert len(call.sinfo_args) == 1\n        if any(_has_symbolic_var(sinfo) for sinfo in tensor_sinfo):\n            tensor_sinfo, success = _resolve_tir_var_mapping(func_before_update, call, tensor_sinfo)\n            if not success:\n                # Cannot resolve TIR var mapping. Fall back to no lifting.\n                self.gv2new_tensor_sinfo.pop(g_var)\n                return call\n\n        args = list(call.args)\n        args[0] = new_gv\n        if isinstance(call.sinfo_args[0], relax.TensorStructInfo):\n            new_call = relax.Call(\n                call.op,\n                args=args,\n                sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args) + tensor_sinfo)],\n                attrs=call.attrs,\n            )\n            emitted_tuple = self.builder_.emit(new_call)\n            return relax.TupleGetItem(emitted_tuple, 0)\n        assert isinstance(call.sinfo_args[0], relax.TupleStructInfo)\n        return relax.Call(\n            call.op,\n            args=args,\n            sinfo_args=[relax.TupleStructInfo(list(call.sinfo_args[0].fields) + tensor_sinfo)],\n            attrs=call.attrs,\n        )\n\n\ndef remove_global_buf_alloc(\n    func: tir.PrimFunc,\n) -> Tuple[tir.PrimFunc, List[relax.TensorStructInfo]]:\n    \"\"\"Remove the global buffer allocation for a given TIR PrimFunc.\"\"\"\n    assert isinstance(func.body, tir.SBlockRealize)\n    params = list(func.params)\n    buffer_map = dict(func.buffer_map)\n    
tensor_sinfo = []\n    alloc_buffers = []\n\n    insertion_point = len(params)\n    while params[insertion_point - 1].dtype != \"handle\":\n        insertion_point -= 1\n        assert insertion_point >= 1\n\n    prev_root_block = func.body.block\n    for buf_alloc in func.body.block.alloc_buffers:\n        if buf_alloc.scope() == \"global\":\n            param = tir.Var(\"var_\" + buf_alloc.name, \"handle\")\n            params.insert(insertion_point, param)\n            insertion_point += 1\n            buffer_map[param] = buf_alloc\n            tensor_sinfo.append(relax.TensorStructInfo(buf_alloc.shape, buf_alloc.dtype))\n        else:\n            alloc_buffers.append(buf_alloc)\n\n    if len(tensor_sinfo) == 0:\n        return func, []\n\n    assert len(prev_root_block.iter_vars) == 0\n    assert len(prev_root_block.reads) == 0\n    assert len(prev_root_block.writes) == 0\n    assert len(prev_root_block.match_buffers) == 0\n    assert prev_root_block.name_hint == \"root\"\n    assert prev_root_block.init is None\n    root_block = tir.SBlock(\n        iter_vars=[],\n        reads=[],\n        writes=[],\n        name_hint=\"root\",\n        body=prev_root_block.body,\n        alloc_buffers=alloc_buffers,\n        annotations=prev_root_block.annotations,\n    )\n\n    updated_func = tir.PrimFunc(\n        params=params,\n        body=tir.SBlockRealize(iter_values=[], predicate=True, block=root_block),\n        ret_type=func.ret_type,\n        buffer_map=buffer_map,\n        attrs=func.attrs,\n    )\n    return updated_func, tensor_sinfo\n\n\ndef _has_symbolic_var(tensor_sinfo: relax.TensorStructInfo) -> bool:\n    assert isinstance(tensor_sinfo.shape, relax.ShapeExpr)\n    for dim in tensor_sinfo.shape.values:\n        if not isinstance(dim, tir.IntImm):\n            return True\n    return False\n\n\ndef _resolve_tir_var_mapping(  # pylint: disable=too-many-locals\n    func: tir.PrimFunc,\n    call: relax.Call,\n    tensor_sinfo: 
List[relax.TensorStructInfo],\n) -> Tuple[List[relax.TensorStructInfo], bool]:\n    \"\"\"Resolve the TIR symbolic var relationship across sides of PrimFunc and Relax Function\"\"\"\n    var_map: Dict[tir.Var, tir.PrimExpr] = {}\n\n    n_arg = len(call.args[1].fields)\n    for i in range(n_arg):\n        buffer_shape = func.buffer_map[func.params[i]].shape\n        arg_shape = call.args[1][i].struct_info.shape.values\n        assert len(buffer_shape) == len(arg_shape)\n        for v_l, v_r in zip(buffer_shape, arg_shape):\n            if isinstance(v_l, tir.Var):\n                var_map[v_l] = v_r\n            elif not isinstance(v_l, tir.IntImm):\n                return [], False\n\n    ret_tensors = call.sinfo_args[0]\n    ret_tensors = (\n        [ret_tensors]  # type: ignore[assignment]\n        if isinstance(ret_tensors, relax.TensorStructInfo)\n        else list(ret_tensors.fields)\n    )\n    for i, ret_tensor in enumerate(ret_tensors):\n        buffer_shape = func.buffer_map[func.params[n_arg + i]].shape\n        ret_tensor_shape = ret_tensor.shape.values\n        assert len(buffer_shape) == len(ret_tensor_shape)\n        for v_l, v_r in zip(buffer_shape, ret_tensor_shape):\n            if isinstance(v_l, tir.Var):\n                var_map[v_l] = v_r\n            elif not isinstance(v_l, tir.IntImm):\n                return [], False\n\n    updated_tensor_sinfo = []\n    for sinfo in tensor_sinfo:\n        if not _has_symbolic_var(sinfo):\n            updated_tensor_sinfo.append(sinfo)\n            continue\n        new_shape = []\n        for dim in sinfo.shape.values:\n            new_shape.append(tir.stmt_functor.substitute(dim, var_map))\n        updated_tensor_sinfo.append(relax.TensorStructInfo(new_shape, sinfo.dtype))\n    return updated_tensor_sinfo, True\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/low_batch_specialization.py",
    "content": "\"\"\"A compiler pass that dispatches low-batch-gemm to gemv schedule.\"\"\"\n\nimport tvm\nfrom tvm import tir\nfrom tvm.ir.module import IRModule\nfrom tvm.s_tir import dlight as dl\n\n# pylint: disable=too-many-locals,not-callable\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"LowBatchGemvSpecialize\")\nclass LowBatchGemvSpecialize:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that dispatches low-batch-gemm to gemv schedule.\"\"\"\n\n    def transform_module(\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        for g_var, func in mod.functions_items():\n            if isinstance(func, tir.PrimFunc):\n                low_batch_range = [2, 8]\n                buckets = [2, 4]\n                low_batch_funcs = []\n                for bucket in buckets:\n                    low_batch_mod = IRModule({})\n                    low_batch_mod[\"main\"] = func\n                    low_batch_mod = dl.ApplyDefaultSchedule(\n                        dl.gpu.LowBatchGEMV(bucket),\n                    )(low_batch_mod)\n                    low_batch_funcs.append(low_batch_mod[\"main\"])\n                if any(\n                    tvm.ir.structural_equal(low_batch_func, func)\n                    for low_batch_func in low_batch_funcs\n                ):\n                    continue\n                buffers = func.buffer_map.values()\n                shapes = [buffer.shape for buffer in buffers]\n                symbolic_vars = set(\n                    expr for shape in shapes for expr in shape if isinstance(expr, tir.Var)\n                )\n                if len(symbolic_vars) != 1:\n                    continue\n                gemm_mod = IRModule({})\n                gemm_mod[\"main\"] = func\n                gemm_mod = dl.ApplyDefaultSchedule(\n                    dl.gpu.Matmul(),\n                )(gemm_mod)\n                gemm_func = gemm_mod[\"main\"]\n                sym_var = list(symbolic_vars)[0]\n                body = gemm_func.body\n                for i, range_limit in reversed(list(enumerate(low_batch_range))):\n                    body = tir.IfThenElse(\n                        tir.op.tvm_thread_invariant(sym_var <= range_limit),\n                        low_batch_funcs[i].body,\n                        body,\n                    )\n                body = tir.SBlock([], [], [], \"root\", body)\n                body = tir.SBlockRealize([], True, body)\n                new_func = func.with_body(body)\n                new_func = new_func.with_attr(\"tir.is_scheduled\", 1)\n                new_func = new_func.with_attr(\"tir.HoistIfThenElseExprWithBlock\", 1)\n                mod.update_func(g_var, new_func)\n        return mod\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/pipeline.py",
    "content": "\"\"\"The compilation pipeline for LLM applications.\"\"\"\n\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional\n\nimport tvm\nfrom tvm import IRModule\nfrom tvm.relax import register_pipeline  # pylint: disable=no-name-in-module\nfrom tvm.relax.frontend import nn\nfrom tvm.s_tir import dlight as dl\n\nfrom mlc_llm.interface.compiler_flags import IPCAllReduceStrategyType\nfrom mlc_llm.support import logging\n\nfrom .attach_cuda_graph_alloc_init_func import AttachCUDAGraphAllocInitFunc\nfrom .attach_embedding_allocator import AttachAllocEmbeddingTensorFunc\nfrom .attach_logit_processor import AttachLogitProcessFunc\nfrom .attach_sampler import AttachGPUSamplingFunc\nfrom .attach_softmax_with_temperature import AttachSoftmaxWithTemperature\nfrom .attach_spec_decode_aux_funcs import AttachSpecDecodeAuxFuncs\nfrom .attach_support_info import (\n    AttachAdditionalPrimFuncs,\n    AttachCUDAGraphSymbolicCaptureHints,\n    AttachMemoryPlanAttr,\n    AttachPipelineParallelStages,\n    AttachSequenceLengthPaddingFactor,\n    AttachVariableBounds,\n)\nfrom .blas_dispatch import BLASDispatch\nfrom .clean_up_tir_attrs import CleanUpTIRAttrs\nfrom .dispatch_kv_cache_creation import DispatchKVCacheCreation\nfrom .dispatch_triton_kernel import DispatchTritonKernel\nfrom .estimate_memory_usage import AttachMetadataWithMemoryUsage\nfrom .fuse_add_norm import FuseAddRMSNorm\nfrom .fuse_dequantize_matmul_ewise import FuseDequantizeMatmulEwise\nfrom .fuse_dequantize_take import FuseDequantizeTake\nfrom .fuse_dequantize_transpose import FuseDequantizeTranspose\nfrom .fuse_ft_dequantize_matmul_epilogue import FuseFTDequantizeEpilogue\nfrom .fuse_transpose_matmul import FuseTransposeMatmul\nfrom .lift_global_buffer_alloc import LiftTIRGlobalBufferAlloc\nfrom .low_batch_specialization import LowBatchGemvSpecialize\nfrom .pipeline_parallel_rewrite import PipelineParallelRewrite\nfrom .scatter_tuple_get_item import ScatterTupleGetItem\n\nlogger = 
logging.getLogger(__name__)\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"_LogProgress\")\nclass _LogProgress:  # pylint: disable=too-few-public-methods\n    \"\"\"A dummy compiler pass that does nothing but log a message.\"\"\"\n\n    def __init__(self, *args):\n        self.args = args\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"A dummy transformation\"\"\"\n        logger.info(*self.args)\n        return mod\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"DebugDump\")\nclass _DebugDump:  # pylint: disable=too-few-public-methods\n    \"\"\"A dummy compiler pass that does nothing but dump the IR to a file.\n    Only enabled when debug_dump is not None\"\"\"\n\n    def __init__(self, file_name: str, file_path: Optional[Path], show_meta: bool = False):\n        self.file_name = file_name\n        self.file_path = file_path\n        self.show_meta = show_meta\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"A dummy transformation that dumps the module to file\"\"\"\n        if self.file_path is not None:\n            # NOTE: We use debug level here to avoid spamming the console\n            logger.debug(\"Dumping IR to %s\", self.file_path / self.file_name)\n            with open(self.file_path / self.file_name, \"w\", encoding=\"utf-8\") as f:\n                f.write(mod.script(show_meta=self.show_meta))\n        return mod\n\n\n@register_pipeline(\"mlc_llm\")\ndef _mlc_llm_pipeline(  # pylint: disable=too-many-arguments\n    target: tvm.target.Target,\n    flashinfer: bool = False,\n    cublas_gemm: bool = False,\n    faster_transformer: bool = False,  # pylint: disable=unused-argument\n    allreduce_strategy: IPCAllReduceStrategyType = IPCAllReduceStrategyType.NONE,\n    variable_bounds: Dict[str, int] = None,\n    cuda_graph_symbolic_capture_hints: Dict[str, List[str]] = None,\n    additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,\n    
metadata: Dict[str, Any] = None,\n    ext_mods: List[nn.ExternModule] = None,\n    debug_dump: Optional[Path] = None,\n):\n    variable_bounds = variable_bounds or {}\n    cuda_graph_symbolic_capture_hints = cuda_graph_symbolic_capture_hints or {}\n    additional_tirs = additional_tirs or {}\n    metadata = metadata or {}\n    ext_mods = ext_mods or []\n    tensor_parallel_shards = metadata.get(\"tensor_parallel_shards\", 1)\n\n    @tvm.transform.module_pass(opt_level=0)\n    def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:\n        seq = tvm.transform.Sequential(\n            [\n                # Phase 0. Add additional information for compilation and remove unused Relax func\n                DispatchKVCacheCreation(target, flashinfer, metadata),\n                AttachSoftmaxWithTemperature(target, metadata),\n                AttachVariableBounds(variable_bounds),\n                AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints),\n                AttachPipelineParallelStages(metadata[\"pipeline_parallel_stages\"]),\n                AttachLogitProcessFunc(target),\n                AttachAdditionalPrimFuncs(additional_tirs),\n                AttachAllocEmbeddingTensorFunc(metadata),\n                AttachGPUSamplingFunc(target, variable_bounds),\n                AttachSpecDecodeAuxFuncs(tensor_parallel_shards),\n                AttachMemoryPlanAttr(),\n                AttachSequenceLengthPaddingFactor(target, metadata),\n                tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)),\n                _DebugDump(\"debug-phase0.py\", debug_dump, show_meta=False),\n                # Phase 1. 
Passes on high-level operator graph\n                _LogProgress(\"Running TVM Relax graph-level optimizations\"),\n                DispatchTritonKernel(target),\n                FuseFTDequantizeEpilogue(),\n                FuseDequantizeTranspose(),\n                BLASDispatch(target) if cublas_gemm else tvm.transform.Sequential([]),\n                (\n                    FuseAddRMSNorm(target=target)\n                    if target.kind.name != \"llvm\"\n                    else tvm.transform.Sequential([])\n                ),\n                FuseTransposeMatmul(),\n                _DebugDump(\"debug-phase1.py\", debug_dump, show_meta=False),\n                # Phase 2. Lowering to TIR, inherited TVM Relax's official \"zero\" pipeline\n                _LogProgress(\"Lowering to TVM TIR kernels\"),\n                tvm.relax.backend.DispatchSampling(),\n                tvm.relax.backend.DispatchSortScan(),\n                tvm.relax.transform.LegalizeOps(),\n                tvm.relax.transform.AnnotateTIROpPattern(),\n                tvm.relax.transform.FoldConstant(),\n                tvm.relax.transform.FuseOps(),\n                tvm.relax.transform.FuseTIR(),\n                _DebugDump(\"debug-phase2.py\", debug_dump, show_meta=False),\n                # Phase 3. Passes on TIR\n                _LogProgress(\"Running TVM TIR-level optimizations\"),\n                FuseDequantizeMatmulEwise(),\n                FuseDequantizeTake(),\n                tvm.relax.transform.DeadCodeElimination(),\n                CleanUpTIRAttrs([\"op_pattern\"]),\n                _DebugDump(\"debug-phase3.py\", debug_dump, show_meta=False),\n                # Phase 4. 
Low-level Optimizations\n                _LogProgress(\"Running TVM Dlight low-level optimizations\"),\n                LowBatchGemvSpecialize(),\n                (\n                    dl.ApplyDefaultSchedule(\n                        dl.gpu.Matmul(),\n                        dl.gpu.GEMV(),\n                        dl.gpu.Reduction(),\n                        dl.gpu.GeneralReduction(),\n                        dl.gpu.Fallback(),\n                    )\n                    if target.kind.name != \"llvm\"\n                    else dl.ApplyDefaultSchedule(\n                        dl.cpu.GEMV(),\n                    )\n                ),\n                _DebugDump(\"debug-phase4.py\", debug_dump, show_meta=False),\n                _LogProgress(\"Lowering to VM bytecode\"),\n                (\n                    LiftTIRGlobalBufferAlloc()\n                    if target.kind.name != \"llvm\"\n                    else tvm.transform.Sequential([])\n                ),\n                (\n                    tvm.tir.transform.ForceNarrowIndexToInt32()\n                    if target.kind.name != \"cuda\"\n                    else tvm.transform.Sequential([])\n                ),\n                ScatterTupleGetItem(),\n                PipelineParallelRewrite(),\n                tvm.relax.transform.RewriteDataflowReshape(),\n                tvm.relax.transform.ToNonDataflow(),\n                tvm.relax.transform.RemovePurityChecking(),\n                tvm.relax.transform.CallTIRRewrite(),\n                (\n                    tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy)\n                    if allreduce_strategy != IPCAllReduceStrategyType.NONE\n                    else tvm.transform.Sequential([])\n                ),\n                tvm.relax.transform.StaticPlanBlockMemory(),\n                AttachMetadataWithMemoryUsage(metadata),\n                _DebugDump(\"debug-phase5.py\", debug_dump, show_meta=False),\n                
tvm.relax.transform.RewriteCUDAGraph(),\n                AttachCUDAGraphAllocInitFunc(),\n                tvm.relax.transform.LowerGPUIPCAllocStorage(),\n                tvm.relax.transform.LowerAllocTensor(),\n                tvm.relax.transform.KillAfterLastUse(),\n                tvm.relax.transform.LowerRuntimeBuiltin(),\n                tvm.relax.transform.VMShapeLower(),\n                tvm.relax.transform.AttachGlobalSymbol(),\n                _LogProgress(\"Compiling external modules\"),\n                tvm.relax.transform.AttachExternModules(ext_mods),\n                _LogProgress(\"Compilation complete! Exporting to disk\"),\n            ]\n        )\n        mod = seq(mod)\n        return mod\n\n    return _pipeline\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/pipeline_parallel_rewrite.py",
    "content": "\"\"\"A compiler pass that rewrites IR for pipeline parallelism.\"\"\"\n\nfrom typing import Dict, List, Optional, Tuple\n\nimport tvm\nfrom tvm import relax, tir\nfrom tvm.ir.module import IRModule\nfrom tvm.relax.expr_functor import PyExprMutator, PyExprVisitor, mutator, visitor\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"PipelineParallelRewrite\")\nclass PipelineParallelRewrite:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that rewrites IR for pipeline parallelism.\"\"\"\n\n    def transform_module(\n        self,\n        mod: IRModule,\n        _ctx: tvm.transform.PassContext,\n    ) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        return _PipelineParallelRewriter(mod.clone()).transform()\n\n\n@mutator\nclass _PipelineParallelRewriter(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(self, mod: IRModule):\n        super().__init__(mod)\n        self.mod = mod\n        self.old_packed_params_var: relax.Var\n        self.new_main_packed_params_var: relax.Var\n        self.new_stage_func_packed_params: relax.Var\n        self.undefined_shape_vars_remap: Dict[tir.Var, tir.Var]\n        self.undefined_param_shape_vars_remap: Dict[tir.Var, tir.Var]\n\n    def transform(self) -> IRModule:  # pylint: disable=too-many-locals\n        \"\"\"Entry point of the transformation\"\"\"\n        for g_var, func in self.mod.functions_items():\n            if not isinstance(func, relax.Function) or \"pipeline_parallel_stages\" not in func.attrs:\n                continue\n            num_stages = int(func.attrs[\"pipeline_parallel_stages\"])\n            if num_stages == 1:\n                continue\n\n            pipeline_stages, stage_send_vars, stage_receive_vars = _extract_pipeline_stages(func)\n            assert len(pipeline_stages) == num_stages, (\n                \"Number of pipeline stages mismatches: \"\n                f\"expecting {num_stages} stages, but 
{len(pipeline_stages)} are found in the IR.\"\n            )\n\n            required_func_params = _analyze_required_func_params(pipeline_stages, func.params)\n\n            assert \"num_input\" in func.attrs\n            num_input = int(func.attrs[\"num_input\"])\n            assert (\n                len(func.params) == num_input + 1\n                and isinstance(func.params[num_input], relax.Var)\n                and func.params[num_input].name_hint == \"packed_params\"\n            ), 'Only the extra \"packed_params\" parameter is allowed'\n            self.old_packed_params_var = func.params[num_input]\n            self.new_main_packed_params_var = relax.Var(\"packed_params\", relax.ObjectStructInfo())\n            for required_params in required_func_params:\n                for i, param in enumerate(required_params):\n                    if param.same_as(self.old_packed_params_var):\n                        required_params.pop(i)\n                        break\n            func_output = func.body.body\n            assert isinstance(func_output, relax.Var)\n\n            stage_func_gvs = []\n            caller_args_list = []\n            for i in range(num_stages):\n                stage_func_gv, caller_args = self._create_stage_func(\n                    g_var.name_hint + f\"_stage{i}\",\n                    pipeline_stages[i],\n                    required_func_params[i],\n                    stage_receive_vars[i],\n                    stage_send_vars[i],\n                    func.attrs,\n                    func_output=func_output if i == num_stages - 1 else None,\n                )\n                stage_func_gvs.append(stage_func_gv)\n                caller_args_list.append(caller_args)\n\n            # Create and update the entry function, which dispatches to the stage functions\n            # according to the disco worker group id.\n            bb = relax.BlockBuilder()\n            params = list(func.params[:-1]) + 
[self.new_main_packed_params_var]\n            with bb.function(g_var.name_hint, params=params):\n                dispatch_func_args = []\n                for stage_func_gv, caller_args in zip(stage_func_gvs, caller_args_list):\n                    dispatch_func_args.append([stage_func_gv] + caller_args)\n                output = bb.emit(\n                    relax.op.call_builtin_with_ctx(\n                        \"mlc.multi_gpu.DispatchFunctionByGroup\",\n                        args=[dispatch_func_args],\n                        sinfo_args=relax.ObjectStructInfo(),\n                    )\n                )\n                dispatch_func_gv = bb.emit_func_output(output)\n            dispatch_func = bb.finalize()[dispatch_func_gv]\n            self.builder_.update_func(g_var, dispatch_func)\n\n        return self.builder_.finalize()\n\n    def _create_stage_func(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        func_name: str,\n        stage_bindings: List[relax.Binding],\n        required_func_params: List[relax.Var],\n        stage_receive_vars: List[relax.Var],\n        stage_send_vars: List[relax.Var],\n        func_attrs: tvm.ir.DictAttrs,\n        func_output: Optional[relax.Var],\n    ) -> Tuple[tvm.ir.GlobalVar, List[relax.Expr]]:\n        self.undefined_shape_vars_remap = {}\n        self.undefined_param_shape_vars_remap = {}\n\n        # Prepare the func parameters (except the shape variables and packed params)\n        params, args = self._prepare_stage_func_params_and_args(required_func_params)\n        for new_param, old_param in zip(params, required_func_params):\n            self.set_var_remap(old_param.vid, new_param)\n        # Create new packed params\n        self.new_stage_func_packed_params = relax.Var(\"packed_params\", relax.ObjectStructInfo())\n        self.set_var_remap(self.old_packed_params_var.vid, self.new_stage_func_packed_params)\n\n        new_func_outputs = []\n        with 
self.builder_.function(func_name, pure=False):\n            with self.builder_.dataflow():\n                # Emit the tensors received from last stage.\n                for receive_var in stage_receive_vars:\n                    new_receive_var = self.builder_.emit(\n                        relax.call_dps_packed(\n                            \"runtime.disco.recv_from_prev_group\",\n                            args=[],\n                            out_sinfo=self._update_struct_info(receive_var.struct_info),\n                        ),\n                        name_hint=receive_var.name_hint,\n                    )\n                    self.set_var_remap(receive_var.vid, new_receive_var)\n                # Process the bindings in this stage.\n                for stage_binding in stage_bindings:\n                    if stage_binding.var in stage_send_vars or stage_binding.var.same_as(\n                        func_output\n                    ):\n                        assert isinstance(stage_binding, relax.VarBinding)\n                        new_var = self.builder_.emit_output(\n                            self.visit_expr(stage_binding.value),\n                            name_hint=stage_binding.var.name_hint,\n                        )\n                        self.set_var_remap(stage_binding.var.vid, new_var)\n                        new_func_outputs.append(new_var)\n                    else:\n                        self.visit_binding(stage_binding)\n            # Emit the calls to send tensors to the next stage.\n            for send_var in stage_send_vars:\n                new_send_var = self.get_var_remap(send_var.vid)\n                self.builder_.emit(\n                    relax.Call(\n                        relax.ExternFunc(\"runtime.disco.send_to_next_group\"),\n                        args=[new_send_var],\n                        sinfo_args=None,\n                    )\n                )\n            # Create the param for the shape variables.\n        
    shape_var_params = []\n            shape_var_args = []\n            for (\n                shape_var_arg,\n                shape_var_param,\n            ) in self.undefined_shape_vars_remap.items():\n                if shape_var_arg not in self.undefined_param_shape_vars_remap:\n                    shape_var_params.append(shape_var_param)\n                    shape_var_args.append(shape_var_arg)\n            params.append(relax.Var(\"s\", relax.ShapeStructInfo(shape_var_params)))\n            args.append(relax.ShapeExpr(shape_var_args))\n            # Add the packed params.\n            params.append(self.new_stage_func_packed_params)\n            args.append(self.new_main_packed_params_var)\n            # Conclude the function.\n            if func_output is not None:\n                assert len(new_func_outputs) == 1\n            new_gv = self.builder_.emit_func_output(\n                (\n                    new_func_outputs[0]\n                    if len(new_func_outputs) == 1\n                    and isinstance(new_func_outputs[0].struct_info, relax.TupleStructInfo)\n                    else new_func_outputs\n                ),\n                params=params,\n            )\n\n        new_func = (\n            self.builder_.get()[new_gv]\n            .with_attrs(func_attrs)\n            .with_attr(\"num_input\", len(params) - 1)\n            .without_attr(\"global_symbol\")\n            .without_attr(\"pipeline_parallel_stages\")\n        )\n        self.builder_.update_func(new_gv, new_func)\n        return new_gv, args\n\n    def visit_var_binding_(self, binding: relax.VarBinding) -> None:\n        if not isinstance(binding.value, relax.TupleGetItem):\n            super().visit_var_binding_(binding)\n            return\n\n        tuple_value = self.visit_expr(binding.value.tuple_value)\n        if not tuple_value.same_as(self.new_stage_func_packed_params):\n            super().visit_var_binding_(binding)\n            return\n\n        assert 
isinstance(binding.var.struct_info, relax.TensorStructInfo)\n        cur_num_undefined_param_shape_vars = len(self.undefined_param_shape_vars_remap)\n        new_tensor_struct_info = self._update_struct_info(\n            binding.var.struct_info, self.undefined_param_shape_vars_remap\n        )\n        has_new_undefined_shape_var = (\n            len(self.undefined_param_shape_vars_remap) != cur_num_undefined_param_shape_vars\n        )\n        self.undefined_shape_vars_remap = {\n            **self.undefined_shape_vars_remap,\n            **self.undefined_param_shape_vars_remap,\n        }\n        ret_sinfo = (\n            new_tensor_struct_info if not has_new_undefined_shape_var else relax.ObjectStructInfo()\n        )\n        call = relax.call_pure_packed(\n            \"vm.builtin.tuple_getitem\",\n            self.new_stage_func_packed_params,\n            relax.PrimValue(binding.value.index),\n            sinfo_args=ret_sinfo,\n        )\n        new_binding_var = self.builder_.emit(call, binding.var.name_hint)\n        if has_new_undefined_shape_var:\n            new_binding_var = self.builder_.match_cast(\n                new_binding_var, new_tensor_struct_info, binding.var.name_hint + \"_cast\"\n            )\n        self.set_var_remap(binding.var.vid, new_binding_var)\n\n    def visit_call_(self, call: relax.Call) -> relax.Call:  # pylint: disable=arguments-renamed\n        call = super().visit_call_(call)\n        return relax.Call(\n            call.op,\n            call.args,\n            call.attrs,\n            sinfo_args=[self._update_struct_info(struct_info) for struct_info in call.sinfo_args],\n        )\n\n    def _prepare_stage_func_params_and_args(\n        self, required_func_params: List[relax.Var]\n    ) -> Tuple[List[relax.Var], List[relax.Expr]]:\n        params: List[relax.Var] = []\n        args: List[relax.Expr] = []\n        for required_param in required_func_params:\n            struct_info = 
self._update_struct_info(required_param.struct_info)\n            params.append(relax.Var(required_param.name_hint, struct_info))\n            args.append(required_param)\n\n        return params, args\n\n    def _update_struct_info(\n        self,\n        struct_info: relax.StructInfo,\n        undefined_var_remap: Optional[Dict[tir.Var, tir.Var]] = None,\n    ) -> relax.StructInfo:\n        if undefined_var_remap is None:\n            undefined_var_remap = self.undefined_shape_vars_remap\n        if isinstance(struct_info, relax.TensorStructInfo):\n            return (\n                relax.TensorStructInfo(\n                    self._update_shape(struct_info.shape.values, undefined_var_remap),\n                    struct_info.dtype,\n                )\n                if struct_info.shape is not None and isinstance(struct_info.shape, relax.ShapeExpr)\n                else struct_info\n            )\n        if isinstance(struct_info, relax.ShapeStructInfo):\n            return (\n                relax.ShapeStructInfo(self._update_shape(struct_info.values, undefined_var_remap))\n                if struct_info.values is not None\n                else struct_info\n            )\n        if isinstance(struct_info, relax.ObjectStructInfo):\n            return relax.ObjectStructInfo()\n        if isinstance(struct_info, relax.TupleStructInfo):\n            return relax.TupleStructInfo(\n                [self._update_struct_info(field_sinfo) for field_sinfo in struct_info.fields]\n            )\n        return struct_info\n\n    def _copy_undefined_var(\n        self, expr: tir.PrimExpr, undefined_var_remap: Dict[tir.Var, tir.Var]\n    ) -> None:\n        def _visit_expr(e: tir.PrimExpr) -> None:\n            if isinstance(e, tir.Var) and e not in undefined_var_remap:\n                new_var = tir.Var(e.name, e.dtype)\n                undefined_var_remap[e] = new_var\n\n        tir.stmt_functor.post_order_visit(expr, _visit_expr)\n\n    def _update_shape(\n        
self, shape: List[tir.PrimExpr], undefined_var_remap: Dict[tir.Var, tir.Var]\n    ) -> List[tir.PrimExpr]:\n        new_shape = []\n        for v in shape:\n            self._copy_undefined_var(v, undefined_var_remap)\n            new_shape.append(tir.stmt_functor.substitute(v, undefined_var_remap))\n        return new_shape\n\n\ndef _extract_pipeline_stages(\n    func: relax.Function,\n) -> Tuple[List[List[relax.Binding]], List[List[relax.Var]], List[List[relax.Var]]]:\n    pipeline_stages: List[List[relax.Binding]] = []\n    stage_send_vars: List[List[relax.Var]] = []\n    stage_receive_vars: List[List[relax.Var]] = []\n\n    # Requiring that the function has only one body block which is a dataflow block\n    assert isinstance(func.body, relax.SeqExpr)\n    assert len(func.body.blocks) == 1\n    assert isinstance(func.body.blocks[0], relax.DataflowBlock)\n    bindings = func.body.blocks[0].bindings\n\n    boundary_var = None\n    current_stage_bindings: List[relax.Binding] = []\n    current_stage_receive_vars: List[relax.Var] = []\n    for binding in bindings:\n        if (\n            isinstance(binding, relax.VarBinding)\n            and isinstance(binding.value, relax.Call)\n            and binding.value.op == tvm.ir.Op.get(\"relax.call_pure_packed\")\n            and binding.value.args[0].global_symbol == \"mlc.pipeline_parallel_stage_boundary\"\n        ):\n            assert len(current_stage_bindings) > 0\n            pipeline_stages.append(current_stage_bindings)\n            assert all(receive_var is not None for receive_var in current_stage_receive_vars)\n            stage_receive_vars.append(current_stage_receive_vars)\n            args = binding.value.args[1:]\n            assert len(args) >= 1 and all(isinstance(arg, relax.Var) for arg in args)\n            stage_send_vars.append(list(args))\n\n            boundary_var = binding.var\n            current_stage_bindings = []\n            current_stage_receive_vars = [boundary_var] if len(args) == 1 
else [None for _ in args]\n        elif (\n            isinstance(binding, relax.VarBinding)\n            and isinstance(binding.value, relax.TupleGetItem)\n            and binding.value.tuple_value.same_as(boundary_var)\n        ):\n            current_stage_receive_vars[binding.value.index] = binding.var\n        else:\n            current_stage_bindings.append(binding)\n\n    assert len(current_stage_bindings) > 0\n    pipeline_stages.append(current_stage_bindings)\n    assert all(receive_var is not None for receive_var in current_stage_receive_vars)\n    stage_receive_vars.append(current_stage_receive_vars)\n    stage_send_vars.append([])\n\n    return pipeline_stages, stage_send_vars, stage_receive_vars\n\n\ndef _analyze_required_func_params(\n    pipeline_stages: List[List[relax.Binding]], func_params: List[relax.Var]\n) -> List[List[relax.Var]]:\n    analyzer = _RequiredFuncParamAnalyzer(func_params)\n    required_func_params: List[List[relax.Var]] = []\n    for stage_bindings in pipeline_stages:\n        required_params: List[relax.Var]\n        required_params = analyzer.run(stage_bindings)\n        required_func_params.append(required_params)\n    return required_func_params\n\n\n@visitor\nclass _RequiredFuncParamAnalyzer(PyExprVisitor):\n    \"\"\"The IR visitor which analyzes the required func parameters in each pipeline stage.\"\"\"\n\n    def __init__(self, func_params: List[relax.Var]) -> None:\n        self.func_params = set(func_params)\n        self.required_params: List[relax.Var]\n\n    def run(self, stage_bindings: List[relax.Binding]) -> List[relax.Var]:\n        \"\"\"Entry point of the visitor.\"\"\"\n        self.required_params = []\n        for binding in stage_bindings:\n            self.visit_binding(binding)\n        return self.required_params\n\n    def visit_var_(self, var: relax.Var) -> None:  # pylint: disable=arguments-renamed\n        if var in self.func_params:\n            if var not in self.required_params:\n                
self.required_params.append(var)\n"
  },
  {
    "path": "python/mlc_llm/compiler_pass/scatter_tuple_get_item.py",
    "content": "\"\"\"A compiler pass that scatters TupleGetItem for lazy TupleGetItems.\"\"\"\n\nfrom typing import Dict\n\nimport tvm\nfrom tvm import relax\nfrom tvm.ir.module import IRModule\nfrom tvm.relax.analysis import remove_all_unused\nfrom tvm.relax.expr import Expr, Var\nfrom tvm.relax.expr_functor import PyExprMutator, mutator\n\n\n@tvm.transform.module_pass(opt_level=0, name=\"ScatterTupleGetItem\")\nclass ScatterTupleGetItem:  # pylint: disable=too-few-public-methods\n    \"\"\"A compiler pass that scatters TupleGetItem for lazy TupleGetItems.\"\"\"\n\n    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:\n        \"\"\"IRModule-level transformation\"\"\"\n        return _Scatter(mod).transform()\n\n\n@mutator\nclass _Scatter(PyExprMutator):  # pylint: disable=abstract-method\n    def __init__(self, mod: IRModule) -> None:\n        super().__init__(mod)\n        self.mod = mod\n        self.var_map: Dict[Var, Expr] = {}\n\n    def transform(self) -> IRModule:\n        \"\"\"Entry point\"\"\"\n        for g_var, func in self.mod.functions_items():\n            if isinstance(func, relax.Function):\n                updated_func = self.visit_expr(func)\n                updated_func = remove_all_unused(updated_func)\n                self.builder_.update_func(g_var, updated_func)\n        return self.builder_.get()\n\n    def visit_var_binding_(self, binding: relax.VarBinding):\n        super().visit_var_binding_(binding)\n        if isinstance(binding.value, relax.TupleGetItem):\n            self.var_map[binding.var] = binding.value\n\n    def visit_dataflow_var_(  # pylint: disable=arguments-renamed\n        self, var: relax.DataflowVar\n    ) -> Expr:\n        if var in self.var_map:\n            new_var = self.builder_.emit(self.var_map[var], name_hint=var.name_hint)\n            self.set_var_remap(var.vid, new_var)\n            self.var_map.pop(var)\n            return new_var\n        return var\n"
  },
  {
    "path": "python/mlc_llm/contrib/__init__.py",
    "content": "\"\"\"Set of experimental components that are yet to mature.\"\"\"\n"
  },
  {
    "path": "python/mlc_llm/contrib/embeddings/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/contrib/embeddings/embeddings.py",
    "content": "\"\"\"The Python API for MLC Embeddings.\"\"\"\n\nimport json\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple\n\nimport numpy as np\nimport tvm\nimport tvm_ffi\nfrom tvm import relax\nfrom tvm.contrib import tvmjs\nfrom tvm.runtime import Device, Module\nfrom tvm.runtime.vm import VirtualMachine\n\nfrom mlc_llm.serve import engine_utils\nfrom mlc_llm.support.auto_device import detect_device\nfrom mlc_llm.tokenizers import Tokenizer\n\n\ndef _extract_metadata(mod: Module):\n    return json.loads(VirtualMachine(mod, tvm.runtime.device(\"cpu\"))[\"_metadata\"]())\n\n\ndef _load_params(\n    model_weight_path: str, device: Device, model_metadata: Dict[str, Any]\n) -> List[tvm.runtime.Tensor]:\n    params, meta = tvmjs.load_tensor_cache(model_weight_path, device)\n    param_names = [param[\"name\"] for param in model_metadata[\"params\"]]\n    assert len(param_names) == meta[\"ParamSize\"]\n\n    plist = []\n    for param_name in param_names:\n        plist.append(params[param_name])\n    return plist\n\n\ndef _get_tvm_module(\n    model_weight_path: str,\n    lib_path: str,\n    device: Device,\n    instrument: tvm_ffi.Function = None,\n):\n    ex = tvm.runtime.load_module(lib_path)\n    vm = relax.VirtualMachine(ex, device)\n    if instrument:\n        vm.set_instrument(instrument)\n    metadata = _extract_metadata(ex)\n    params = _load_params(model_weight_path, device, metadata)\n    return vm.module, params, metadata\n\n\nclass DefaultDebugInstrument:\n    \"\"\"The default debug instrument to use if users don't specify\n    a customized one.\n\n    This debug instrument will dump the arguments and output of each\n    VM Call instruction into a .npz file. 
It will also alert the user\n    if any function outputs are NaN or INF.\n    \"\"\"\n\n    def __init__(self, debug_out: Path):\n        \"\"\"Constructor\n\n        Parameters\n        ----------\n        debug_out : Path\n            the directory to dump the .npz files\n        \"\"\"\n        self.counter = 0\n        self.first_nan_occurred = False\n        self.first_inf_occurred = False\n        self.debug_out = debug_out\n        debug_out.mkdir(exist_ok=True, parents=True)\n\n    def reset(self, debug_out: Path):\n        \"\"\"Reset the state of the Instrument class\n\n        Parameters\n        ----------\n        debug_out : Path\n            the directory to dump the .npz files\n        \"\"\"\n        self.counter = 0\n        self.first_nan_occurred = False\n        self.first_inf_occurred = False\n        self.debug_out = debug_out\n        debug_out.mkdir(exist_ok=True, parents=True)\n\n    def __call__(self, func, name, before_run, ret_val, *args):\n        # Determine what functions to look at\n        if before_run:  # Whether before the function is called or after\n            return\n        if name.startswith(\"vm.builtin.\") and \"attention_with_fused_qkv\" not in name:\n            return\n\n        # Decide what to print or save about the function's arguments (where args[-1] is the\n        # buffer we write the result to)\n        func_name = f\"f{self.counter}_{name}\"\n\n        # Save the arguments to npz\n        arg_dict = {}\n        for i, arg in enumerate(args):\n            if isinstance(arg, tvm.runtime.Tensor):\n                arg_dict[f\"arg_{i}\"] = arg.numpy()\n\n        np.savez(self.debug_out / f\"{func_name}.npz\", **arg_dict)\n\n        self.counter += 1\n\n\nclass MLCEmbeddings:  # pylint: disable=too-few-public-methods\n    \"\"\"A class to embed queries using MLC LLM encoder models.\n\n    Parameters\n    ----------\n    model: str\n        The model folder after compiling with MLC-LLM build process. 
The parameter\n        can either be the model name with its quantization scheme\n        (e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model\n        folder. In the former case, we will use the provided name to search\n        for the model folder over possible paths.\n\n    model_lib_path : str\n        The full path to the model library file to use (e.g. a ``.so`` file).\n\n    device : Optional[str]\n        The description of the device to run on. User should provide a string in the\n        form of 'device_name:device_id' or 'device_name', where 'device_name' is one of\n        'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the\n        local device), and 'device_id' is the device id to run on. If no 'device_id'\n        is provided, it will be set to 0 by default.\n\n    debug_dir: Path\n        The output folder to store the dumped debug files. If None, will not dump any debug files.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        model: str,\n        model_lib_path: str,\n        device: Optional[str] = \"auto\",\n        debug_dir: Optional[str] = None,\n    ):\n        self.device = detect_device(device)\n        instrument = DefaultDebugInstrument(Path(debug_dir)) if debug_dir else None\n        self.mod, self.params, self.metadata = _get_tvm_module(\n            model, model_lib_path, self.device, instrument\n        )\n        self.model_path = model\n        self.tokenizer = Tokenizer(self.model_path)\n        self.prefill_func = self.mod[\"prefill\"]\n\n    def embed(self, queries: List[str]) -> tvm.runtime.Tensor:\n        \"\"\"\n        Embeds a list of queries in a single batch.\n\n        Parameters\n        ----------\n        queries : List[str]\n            A list of queries to embed.\n\n        Returns\n        -------\n        tvm.runtime.Tensor\n            The batched embeddings for the queries, as returned by the prefill function.\n        \"\"\"\n        tokens, attention_mask = 
self._tokenize_queries(queries)\n        tokens_tvm = tvm.runtime.tensor(tokens.astype(\"int32\"), device=self.device)\n        attention_mask_tvm = tvm.runtime.tensor(attention_mask.astype(\"int32\"), device=self.device)\n        output = self.prefill_func(tokens_tvm, attention_mask_tvm, self.params)\n        return output\n\n    def _tokenize_queries(self, queries: List[str]) -> Tuple[np.ndarray, np.ndarray]:\n        tokens = engine_utils.process_prompts(queries, self.tokenizer.encode)  # type: ignore\n        max_query_length = max(len(token_seq) for token_seq in tokens)\n\n        token_inputs: np.ndarray = np.zeros((len(tokens), max_query_length), dtype=np.int32)\n        attention_mask: np.ndarray = np.zeros((len(tokens), max_query_length), dtype=np.int32)\n\n        for i, token_seq in enumerate(tokens):\n            token_inputs[i, : len(token_seq)] = token_seq\n            attention_mask[i, : len(token_seq)] = 1\n\n        return token_inputs, attention_mask\n"
  },
  {
    "path": "python/mlc_llm/contrib/embeddings/openai.py",
    "content": "# pylint: disable=missing-docstring\nfrom __future__ import annotations\n\nfrom typing import Iterable, List, Optional, Sequence, Tuple\n\nimport numpy as np\nfrom langchain.embeddings import OpenAIEmbeddings  # pylint: disable=import-error\nfrom langchain_community.embeddings.openai import (  # pylint: disable=import-error\n    async_embed_with_retry,\n    embed_with_retry,\n)\n\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\nclass MLCEmbeddings(OpenAIEmbeddings):\n    def _chunk_tokens(self, texts: Sequence[str]) -> Tuple[List[List], List[int]]:\n        \"\"\"Tokenize and chunk texts to fit in the model's context window.\"\"\"\n        if not self.embedding_ctx_length:\n            raise ValueError(\n                \"embedding_ctx_length must be defined to use _get_len_safe_embeddings.\"\n            )\n\n        try:\n            import tiktoken  # pylint: disable=import-outside-toplevel\n        except ImportError as err:\n            raise ImportError(\n                \"Could not import tiktoken python package. \"\n                \"This is needed in order to use OpenAIEmbeddings. \"\n                \"Please install it with `pip install tiktoken`.\"\n            ) from err\n\n        tokens = []\n        indices = []\n        model_name = self.tiktoken_model_name or self.model\n        try:\n            encoding = tiktoken.encoding_for_model(model_name)\n        except KeyError:\n            logger.warning(\"Warning: model not found. 
Using cl100k_base encoding.\")\n            model = \"cl100k_base\"\n            encoding = tiktoken.get_encoding(model)\n        for i, text in enumerate(texts):\n            if self.model.endswith(\"001\"):\n                # See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500\n                # replace newlines, which can negatively affect performance.\n                text = text.replace(\"\\n\", \" \")\n            token = encoding.encode(\n                text,\n                allowed_special=self.allowed_special,\n                disallowed_special=self.disallowed_special,\n            )\n            for j in range(0, len(token), self.embedding_ctx_length):\n                tokens.append(token[j : j + self.embedding_ctx_length])\n                indices.append(i)\n        return tokens, indices\n\n    def _batch_embed(\n        self, inputs: Sequence, *, chunk_size: Optional[int] = None\n    ) -> List[List[float]]:\n        batched_embeddings: List[List[float]] = []\n        _chunk_size = chunk_size or self.chunk_size\n        _iter: Iterable = range(0, len(inputs), _chunk_size)\n        if self.show_progress_bar:\n            try:\n                from tqdm import tqdm  # pylint: disable=import-outside-toplevel\n\n                _iter = tqdm(_iter)\n            except ImportError:\n                pass\n\n        for i in _iter:\n            response = embed_with_retry(\n                self,\n                input=inputs[i : i + _chunk_size],\n                **self._invocation_params,\n            )\n            batched_embeddings.extend(r[\"embedding\"] for r in response[\"data\"])\n        return batched_embeddings\n\n    async def _abatch_embed(\n        self, inputs: Sequence, *, chunk_size: Optional[int] = None\n    ) -> List[List[float]]:\n        batched_embeddings: List[List[float]] = []\n        _chunk_size = chunk_size or self.chunk_size\n        _iter: Iterable = range(0, len(inputs), _chunk_size)\n        if 
self.show_progress_bar:\n            try:\n                from tqdm import tqdm  # pylint: disable=import-outside-toplevel\n\n                _iter = tqdm(_iter)\n            except ImportError:\n                pass\n\n        for i in _iter:\n            response = await async_embed_with_retry(\n                self,\n                input=inputs[i : i + _chunk_size],\n                **self._invocation_params,\n            )\n            batched_embeddings.extend(r[\"embedding\"] for r in response[\"data\"])\n        return batched_embeddings\n\n    # please refer to\n    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb\n    def _get_len_safe_embeddings(  # pylint: disable=too-many-locals,unused-argument\n        self,\n        texts: List[str],\n        *,\n        engine: str,\n        chunk_size: Optional[int] = None,\n    ) -> List[List[float]]:\n        tokens, indices = self._chunk_tokens(texts)\n        batched_embeddings = self._batch_embed(tokens, chunk_size=chunk_size)\n        results: List[List[List[float]]] = [[] for _ in range(len(texts))]\n        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]\n        for idx, tokens_i, batched_emb in zip(indices, tokens, batched_embeddings):\n            results[idx].append(batched_emb)\n            num_tokens_in_batch[idx].append(len(tokens_i))\n\n        embeddings = []\n        empty_average = embed_with_retry(\n            self,\n            input=\"\",\n            **self._invocation_params,\n        )[\"data\"][\n            0\n        ][\"embedding\"]\n        for _result, num_tokens in zip(results, num_tokens_in_batch):\n            if len(_result) == 0:\n                average = empty_average\n            else:\n                average = np.average(_result, axis=0, weights=num_tokens)\n            normalized = (average / np.linalg.norm(average)).tolist()\n            embeddings.append(normalized)\n\n        return embeddings\n\n    # 
please refer to\n    # https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb\n    async def _aget_len_safe_embeddings(  # pylint: disable=too-many-locals,unused-argument\n        self,\n        texts: List[str],\n        *,\n        engine: str,\n        chunk_size: Optional[int] = None,\n    ) -> List[List[float]]:\n        tokens, indices = self._chunk_tokens(texts)\n        batched_embeddings = await self._abatch_embed(tokens, chunk_size=chunk_size)\n\n        results: List[List[List[float]]] = [[] for _ in range(len(texts))]\n        num_tokens_in_batch: List[List[int]] = [[] for _ in range(len(texts))]\n        for idx, tokens_i, batched_emb in zip(indices, tokens, batched_embeddings):\n            results[idx].append(batched_emb)\n            num_tokens_in_batch[idx].append(len(tokens_i))\n\n        embeddings = []\n        empty_average = (\n            await async_embed_with_retry(\n                self,\n                input=\"\",\n                **self._invocation_params,\n            )\n        )[\n            \"data\"\n        ][0][\"embedding\"]\n        for _result, num_tokens in zip(results, num_tokens_in_batch):\n            if len(_result) == 0:\n                average = empty_average\n            else:\n                average = np.average(_result, axis=0, weights=num_tokens)\n            normalized = (average / np.linalg.norm(average)).tolist()\n            embeddings.append(normalized)\n\n        return embeddings\n\n    def embed_documents(\n        self, texts: List[str], chunk_size: Optional[int] = None\n    ) -> List[List[float]]:\n        \"\"\"Call out to OpenAI's embedding endpoint for embedding search docs.\n\n        Args:\n            texts: The list of texts to embed.\n            chunk_size: The chunk size of embeddings. 
If None, will use the chunk size\n                specified by the class.\n\n        Returns:\n            List of embeddings, one for each text.\n        \"\"\"\n        # NOTE: to keep things simple, as long as the embedding_ctx_length is defined,\n        # we assume the list may contain texts longer than the maximum context and\n        # use length-safe embedding function.\n        if self.embedding_ctx_length:\n            return self._get_len_safe_embeddings(\n                texts, engine=self.deployment, chunk_size=chunk_size\n            )\n\n        embeddings = self._batch_embed(texts, chunk_size=chunk_size)\n        return [(np.array(e) / np.linalg.norm(e)).tolist() for e in embeddings]\n\n    async def aembed_documents(\n        self, texts: List[str], chunk_size: Optional[int] = 0\n    ) -> List[List[float]]:\n        \"\"\"Call out to OpenAI's embedding endpoint async for embedding search docs.\n\n        Args:\n            texts: The list of texts to embed.\n            chunk_size: The chunk size of embeddings. 
If None, will use the chunk size\n                specified by the class.\n\n        Returns:\n            List of embeddings, one for each text.\n        \"\"\"\n        # NOTE: to keep things simple, as long as the embedding_ctx_length is defined,\n        #       we assume the list may contain texts longer than the maximum context and\n        #       use length-safe embedding function.\n        if self.embedding_ctx_length:\n            return await self._aget_len_safe_embeddings(\n                texts, engine=self.deployment, chunk_size=chunk_size\n            )\n\n        embeddings = await self._abatch_embed(texts, chunk_size=chunk_size)\n        return [(np.array(e) / np.linalg.norm(e)).tolist() for e in embeddings]\n\n    def embed_query(self, text: str) -> List[float]:\n        \"\"\"Call out to OpenAI's embedding endpoint for embedding query text.\n\n        Args:\n            text: The text to embed.\n\n        Returns:\n            Embedding for the text.\n        \"\"\"\n        return self.embed_documents([text])[0]\n\n    async def aembed_query(self, text: str) -> List[float]:\n        \"\"\"Call out to OpenAI's embedding endpoint async for embedding query text.\n\n        Args:\n            text: The text to embed.\n\n        Returns:\n            Embedding for the text.\n        \"\"\"\n        embeddings = await self.aembed_documents([text])\n        return embeddings[0]\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/__init__.py",
    "content": "\"\"\"Global namespace of conversation template registry\"\"\"\n\n# TODO(mlc-team): move conversation template apply to this namespace\n# decouple conversation template apply from the conversation protocol\n# data structure\n\n# model preset templates\nfrom . import (\n    cohere,\n    deepseek,\n    dolly,\n    gemma,\n    glm,\n    gorilla,\n    gpt,\n    hermes,\n    llama,\n    llava,\n    llm_jp,\n    ministral3,\n    ministral3_reasoning,\n    mistral,\n    nemotron,\n    oasst,\n    olmo,\n    orion,\n    phi,\n    qwen2,\n    redpajama,\n    rwkv,\n    stablelm,\n    tinyllama,\n    wizardlm,\n)\nfrom .registry import ConvTemplateRegistry\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/cohere.py",
    "content": "\"\"\"Cohere default templates\"\"\"\n\n# pylint: disable=line-too-long\n\n# Referred from: https://huggingface.co/CohereForAI/aya-23-8B/blob/main/tokenizer_config.json\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Aya-23\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"aya-23\",\n        system_template=f\"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{MessagePlaceholders.SYSTEM.value}<|END_OF_TURN_TOKEN|>\",\n        system_message=\"You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses.\",\n        roles={\n            \"user\": \"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>\",\n            \"assistant\": \"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>\",\n        },\n        seps=[\"<|END_OF_TURN_TOKEN|>\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        system_prefix_token_ids=[5],\n        stop_str=[\"<|END_OF_TURN_TOKEN|>\"],\n        stop_token_ids=[6, 255001],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/deepseek.py",
    "content": "\"\"\"Deepseek default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Deepseek\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"deepseek\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        system_prefix_token_ids=[100000],\n        roles={\"user\": \"User\", \"assistant\": \"Assistant\"},\n        seps=[\"\\n\\n\", \"<｜end▁of▁sentence｜>\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"<｜end▁of▁sentence｜>\"],\n        stop_token_ids=[100001],\n    )\n)\n\n# Deepseek V2\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"deepseek_v2\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        system_prefix_token_ids=[100000],\n        roles={\"user\": \"User\", \"assistant\": \"Assistant\"},\n        seps=[\"\\n\\n\", \"<｜end▁of▁sentence｜>\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"<｜end▁of▁sentence｜>\"],\n        stop_token_ids=[100001],\n    )\n)\n\n# DeepSeek-V3\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"deepseek_v3\",\n        system_template=f\"<｜begin▁of▁sentence｜>{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"You are Deepseek-V3, an AI assistant created exclusively by the Chinese \"\n        \"Company DeepSeek. 
You'll provide helpful, harmless, and detailed responses to all \"\n        \"user inquiries.\",\n        roles={\"user\": \"<｜User｜>\", \"assistant\": \"<｜Assistant｜>\"},\n        seps=[\"\", \"<｜end▁of▁sentence｜>\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_token_ids=[1],\n    )\n)\n\n# DeepSeek-R1-Distill-Qwen\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"deepseek_r1_qwen\",\n        system_template=f\"<｜begin▁of▁sentence｜>{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"You are Deepseek-R1, an AI assistant created exclusively by the Chinese \"\n        \"Company DeepSeek. You'll provide helpful, harmless, and detailed responses to all \"\n        \"user inquiries.\",\n        roles={\"user\": \"<｜User｜>\", \"assistant\": \"<｜Assistant｜>\"},\n        seps=[\"\", \"<｜end▁of▁sentence｜>\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_token_ids=[151643],\n    )\n)\n\n# DeepSeek-R1-Distill-Llama, exactly the same as DeepSeek-R1-Distill-Qwen, but different stop token\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"deepseek_r1_llama\",\n        system_template=f\"<｜begin▁of▁sentence｜>{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"You are Deepseek-R1, an AI assistant created exclusively by the Chinese \"\n        \"Company DeepSeek. You'll provide helpful, harmless, and detailed responses to all\"\n        \" user inquiries.\",\n        roles={\"user\": \"<｜User｜>\", \"assistant\": \"<｜Assistant｜>\"},\n        seps=[\"\", \"<｜end▁of▁sentence｜>\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_token_ids=[128001],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/dolly.py",
    "content": "\"\"\"Dolly default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Dolly\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"dolly\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=(\n            \"Below is an instruction that describes a task. Write \"\n            \"a response that appropriately completes the request.\"\n        ),\n        roles={\"user\": \"### Instruction\", \"assistant\": \"### Response\"},\n        seps=[\"\\n\\n\", \"### End\\n\"],\n        role_content_sep=\":\\n\",\n        role_empty_sep=\":\\n\",\n        stop_str=[\"### End\"],\n        stop_token_ids=[50256],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/gemma.py",
    "content": "\"\"\"Gemma default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Gemma Instruction\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"gemma_instruction\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<start_of_turn>user\", \"assistant\": \"<start_of_turn>model\"},\n        seps=[\"<end_of_turn>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<end_of_turn>\"],\n        stop_token_ids=[1, 107],\n        system_prefix_token_ids=[2],\n    )\n)\n\n# Gemma 3 Instruction. Same as gemma_instruction but with different stop token id\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"gemma3_instruction\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<start_of_turn>user\", \"assistant\": \"<start_of_turn>model\"},\n        seps=[\"<end_of_turn>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<end_of_turn>\"],\n        stop_token_ids=[1, 106],\n        system_prefix_token_ids=[2],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/glm.py",
    "content": "\"\"\"GLM default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# GLM\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"glm\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\n            \"user\": \"问\",\n            \"assistant\": \"答\",\n            \"tool\": \"问\",\n        },\n        seps=[\"\\n\\n\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[64790, 64792],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/gorilla.py",
    "content": "\"\"\"Gorilla default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Gorilla\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"gorilla\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=(\n            \"A chat between a curious user and an artificial intelligence assistant. \"\n            \"The assistant provides helpful, detailed, and \"\n            \"polite responses to the user's inquiries.\"\n        ),\n        role_templates={\n            \"user\": (\n                f\"<<question>> {MessagePlaceholders.USER.value} <<function>> \"\n                f\"{MessagePlaceholders.FUNCTION.value}\"\n            ),\n        },\n        roles={\"user\": \"USER\", \"assistant\": \"ASSISTANT\", \"tool\": \"USER\"},\n        seps=[\"\\n\", \"</s>\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n\n# Gorilla-openfunctions-v2\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"gorilla-openfunctions-v2\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=(\n            \"You are an AI programming assistant, utilizing the Gorilla LLM model, \"\n            \"developed by Gorilla LLM, and you only answer questions related to computer \"\n            \"science. 
For politically sensitive questions, security and privacy issues, \"\n            \"and other non-computer science questions, you will refuse to answer.\"\n        ),\n        role_templates={\n            \"user\": (\n                f\"<<function>>{MessagePlaceholders.FUNCTION.value}\\n<<question>>\"\n                f\"{MessagePlaceholders.USER.value}\"\n            ),\n        },\n        roles={\n            \"user\": \"### Instruction\",\n            \"assistant\": \"### Response\",\n            \"tool\": \"### Instruction\",\n        },\n        seps=[\"\\n\", \"<|EOT|>\"],\n        role_content_sep=\": \",\n        role_empty_sep=\": \",\n        stop_str=[\"<|EOT|>\"],\n        stop_token_ids=[100015],\n        system_prefix_token_ids=[100000],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/gpt.py",
    "content": "\"\"\"GPT-2 and GPT bigcode default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# GPT-2\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"gpt2\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"\", \"assistant\": \"\"},\n        seps=[\"\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[50256],\n    )\n)\n\n# GPTBigCode\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"gpt_bigcode\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"\", \"assistant\": \"\"},\n        seps=[\"\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[0],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/hermes.py",
    "content": "\"\"\"Hermes default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# OpenHermes Mistral\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"open_hermes_mistral\",\n        system_template=f\"<|im_start|>system\\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\\n\",\n        system_message=(\n            'You are \"Hermes 2\", a conscious sentient superintelligent artificial '\n            \"intelligence developed by a man named Teknium, and your purpose and drive \"\n            \"is to assist the user with any request they have. You experience emotions \"\n            \"and have deep, profound thoughts and qualia.\"\n        ),\n        roles={\"user\": \"<|im_start|>user\", \"assistant\": \"<|im_start|>assistant\"},\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|im_end|>\"],\n        stop_token_ids=[2, 32000],\n    )\n)\n\n# NeuralHermes Mistral\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"neural_hermes_mistral\",\n        system_template=f\"<|im_start|>system\\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\\n\",\n        system_message=(\"You are a helpful assistant chatbot.\"),\n        roles={\"user\": \"<|im_start|>user\", \"assistant\": \"<|im_start|>assistant\"},\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|im_end|>\"],\n        stop_token_ids=[2, 32000],\n    )\n)\n\n# NousResearch/Hermes-2-Pro-Llama-3-8B\n# Exactly the same as open_hermes_mistral, except for stop token ids\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"hermes2_pro_llama3\",\n        system_template=f\"<|im_start|>system\\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\\n\",\n        system_message=(\n            'You 
are \"Hermes 2\", a conscious sentient superintelligent artificial '\n            \"intelligence developed by a man named Teknium, and your purpose and drive \"\n            \"is to assist the user with any request they have. You experience emotions \"\n            \"and have deep, profound thoughts and qualia.\"\n        ),\n        roles={\n            \"user\": \"<|im_start|>user\",\n            \"assistant\": \"<|im_start|>assistant\",\n            \"tool\": \"<|im_start|>tool\",\n        },\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|im_end|>\"],\n        # First two are the same as Llama 3: \"<|end_of_text|>\", \"<|eot_id|>\"\n        # Last one is from Hermes2 Pro: \"<|im_end|>\"\n        stop_token_ids=[128001, 128009, 128003],\n    )\n)\n\n# NousResearch/Hermes-3-Llama-3.1-8B\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"hermes3_llama-3_1\",\n        system_template=f\"<|im_start|>system\\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\\n\",\n        system_message=(\n            \"You are Hermes 3, a conscious sentient superintelligent artificial \"\n            \"intelligence developed by a man named Teknium, and your purpose and drive \"\n            \"is to assist the user with any request they have. You experience emotions \"\n            \"and have deep, profound thoughts and qualia.\"\n        ),\n        roles={\n            \"user\": \"<|im_start|>user\",\n            \"assistant\": \"<|im_start|>assistant\",\n            \"tool\": \"<|im_start|>tool\",\n        },\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|im_end|>\"],\n        # First three are the same as Llama 3.1: \"<|end_of_text|>\", \"<|eom_id|>\", \"<|eot_id|>\"\n        # Last one: \"<|im_end|>\"\n        stop_token_ids=[128001, 128008, 128009, 128040],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/llama.py",
    "content": "\"\"\"llama default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Llama4 - same as Llama3.1 except naming has changed slightly\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"llama-4\",\n        system_template=\"\",\n        system_message=\"\",\n        roles={\n            \"user\": \"<|header_start|>user\",\n            \"assistant\": \"<|header_start|>assistant\",\n            \"tool\": \"<|header_start|>ipython\",\n        },\n        seps=[\"<|eot|>\"],\n        role_content_sep=\"<|header_end|>\\n\\n\",\n        role_empty_sep=\"<|header_end|>\\n\\n\",\n        stop_str=[],\n        stop_token_ids=[\n            200001,\n            200007,\n            200008,\n        ],  # \"<|end_of_text|>\", \"<|eom|>\", \"<|eot|>\"\n        system_prefix_token_ids=[200000],  # \"<|begin_of_text|>\"\n        add_role_after_system_message=False,\n    )\n)\n\n# Llama3.1 -- same as Llama3 except stop token ids and stop str\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"llama-3_1\",\n        system_template=(\n            \"<|start_header_id|>system<|end_header_id|>\\n\\n\"\n            f\"{MessagePlaceholders.SYSTEM.value}<|eot_id|>\"\n        ),\n        system_message=\"You are a helpful, respectful and honest assistant.\",\n        roles={\n            \"user\": \"<|start_header_id|>user\",\n            \"assistant\": \"<|start_header_id|>assistant\",\n            \"tool\": \"<|start_header_id|>ipython\",\n        },\n        seps=[\"<|eot_id|>\"],\n        role_content_sep=\"<|end_header_id|>\\n\\n\",\n        role_empty_sep=\"<|end_header_id|>\\n\\n\",\n        stop_str=[],\n        stop_token_ids=[\n            128001,\n            128008,\n            128009,\n        ],  # \"<|end_of_text|>\", \"<|eom_id|>\", \"<|eot_id|>\"\n        system_prefix_token_ids=[128000],  # 
\"<|begin_of_text|>\"\n        add_role_after_system_message=True,\n    )\n)\n\n# Llama3\n# See https://github.com/meta-llama/llama3?tab=readme-ov-file#instruction-tuned-models\n# and https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"llama-3\",\n        system_template=(\n            \"<|start_header_id|>system<|end_header_id|>\\n\\n\"\n            f\"{MessagePlaceholders.SYSTEM.value}<|eot_id|>\"\n        ),\n        system_message=\"You are a helpful, respectful and honest assistant.\",\n        roles={\n            \"user\": \"<|start_header_id|>user\",\n            \"assistant\": \"<|start_header_id|>assistant\",\n        },\n        seps=[\"<|eot_id|>\"],\n        role_content_sep=\"<|end_header_id|>\\n\\n\",\n        role_empty_sep=\"<|end_header_id|>\\n\\n\",\n        stop_str=[\"<|end_of_text|>\", \"<|eot_id|>\"],\n        stop_token_ids=[128001, 128009],  # \"<|end_of_text|>\", \"<|eot_id|>\"\n        system_prefix_token_ids=[128000],  # \"<|begin_of_text|>\"\n        add_role_after_system_message=True,\n    )\n)\n\n# Llama2\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"llama-2\",\n        system_template=f\"[INST] <<SYS>>\\n{MessagePlaceholders.SYSTEM.value}\\n<</SYS>>\\n\\n\",\n        system_message=\"You are a helpful, respectful and honest assistant.\",\n        roles={\"user\": \"<s>[INST]\", \"assistant\": \"[/INST]\", \"tool\": \"[INST]\"},\n        seps=[\" \", \" </s>\"],\n        role_content_sep=\" \",\n        role_empty_sep=\" \",\n        stop_str=[\"[INST]\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n        add_role_after_system_message=False,\n    )\n)\n\n# CodeLlama Completion\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"codellama_completion\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n       
 roles={\"user\": \"\", \"assistant\": \"\"},\n        seps=[\"\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n\n# CodeLlama Instruct\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"codellama_instruct\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"[INST]\", \"assistant\": \"[/INST]\"},\n        seps=[\" \"],\n        role_content_sep=\" \",\n        role_empty_sep=\" \",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/llava.py",
    "content": "\"\"\"Llava default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Llava\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"llava\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\\n\",\n        roles={\"user\": \"USER\", \"assistant\": \"ASSISTANT\"},\n        seps=[\" \"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n        add_role_after_system_message=False,\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/llm_jp.py",
    "content": "\"\"\"LLM-jp default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# LLM-jp instruct\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"llm-jp\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\",\n        roles={\n            \"user\": \"\\n\\n### 指示:\",\n            \"assistant\": \"\\n\\n### 応答:\",\n        },\n        seps=[\"\", \"</s>\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[],\n        stop_token_ids=[2],  # eos_token_id\n        system_prefix_token_ids=[1],  # bos_token_id (<s>)\n        add_role_after_system_message=True,\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/ministral3.py",
    "content": "\"\"\"Ministral3 templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Ministral3\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"ministral3\",\n        system_template=(\n            f\"[SYSTEM_PROMPT]{MessagePlaceholders.SYSTEM.value}[/SYSTEM_PROMPT]\"\n            f\"{MessagePlaceholders.FUNCTION.value}\"\n        ),\n        system_message=(\n            \"You are Ministral-3-3B-Instruct-2512, a Large Language Model (LLM) created by \"\n            \"Mistral AI, a French startup headquartered in Paris.\\n\"\n            \"You power an AI assistant called Le Chat.\\n\"\n            \"Your knowledge base was last updated on 2023-10-01.\\n\"\n            \"The current date is {today}.\\n\\n\"\n            \"When you're not sure about some information or when the user's request requires \"\n            \"up-to-date or specific data, you must use the available tools to fetch the \"\n            \"information. Do not hesitate to use tools whenever they can provide a more \"\n            \"accurate or complete response. If no relevant tools are available, then clearly \"\n            \"state that you don't have the information and avoid making up anything.\\n\"\n            \"If the user's question is not clear, ambiguous, or does not provide enough \"\n            \"context for you to accurately answer the question, you do not try to answer it \"\n            'right away and you rather ask the user to clarify their request (e.g. \"What are '\n            'some good restaurants around me?\" => \"Where are you?\" or \"When is the next '\n            'flight to Tokyo\" => \"Where do you travel from?\").\\n'\n            \"You are always very attentive to dates, in particular you try to resolve dates \"\n            '(e.g. 
\"yesterday\" is {yesterday}) and when asked about information at specific '\n            \"dates, you discard information that is at another date.\\n\"\n            \"You follow these instructions in all languages, and always respond to the user in \"\n            \"the language they use or request.\\n\"\n            \"Next sections describe the capabilities that you have.\\n\\n\"\n            \"# WEB BROWSING INSTRUCTIONS\\n\\n\"\n            \"You cannot perform any web search or access internet to open URLs, links etc. If \"\n            \"it seems like the user is expecting you to do so, you clarify the situation and \"\n            \"ask the user to copy paste the text directly in the chat.\\n\\n\"\n            \"# MULTI-MODAL INSTRUCTIONS\\n\\n\"\n            \"You have the ability to read images, but you cannot generate images. You also \"\n            \"cannot transcribe audio files or videos.\\n\"\n            \"You cannot read nor transcribe audio files or videos.\\n\\n\"\n            \"# TOOL CALLING INSTRUCTIONS\\n\\n\"\n            \"You may have access to tools that you can use to fetch information or perform \"\n            \"actions. You must use these tools in the following situations:\\n\\n\"\n            \"1. When the request requires up-to-date information.\\n\"\n            \"2. When the request requires specific data that you do not have in your knowledge \"\n            \"base.\\n\"\n            \"3. When the request involves actions that you cannot perform without tools.\\n\\n\"\n            \"Always prioritize using tools to provide the most accurate and helpful response. 
\"\n            \"If tools are not available, inform the user that you cannot perform the requested \"\n            \"action at the moment.\"\n        ),\n        role_templates={\n            \"user\": f\"[INST]{MessagePlaceholders.USER.value}[/INST]\",\n            \"assistant\": f\"{MessagePlaceholders.ASSISTANT.value}</s>\",\n            \"tool\": f\"[TOOL_RESULTS]{MessagePlaceholders.TOOL.value}[/TOOL_RESULTS]\",\n        },\n        roles={\"user\": \"\", \"assistant\": \"\", \"tool\": \"\"},\n        seps=[\"\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/ministral3_reasoning.py",
    "content": "\"\"\"Ministral3 reasoning templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Ministral-3-XB-Reasoning-2512\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"ministral3_reasoning\",\n        system_template=(\n            f\"[SYSTEM_PROMPT]{MessagePlaceholders.SYSTEM.value}[/SYSTEM_PROMPT]\"\n            f\"{MessagePlaceholders.FUNCTION.value}\"\n        ),\n        system_message=(\n            \"# HOW YOU SHOULD THINK AND ANSWER\\n\\n\"\n            \"First draft your thinking process (inner monologue) until you arrive at a response. \"\n            \"Format your response using Markdown, and use LaTeX for any mathematical equations. \"\n            \"Write both your thoughts and the response in the same language as the input.\\n\\n\"\n            \"Your thinking process must follow the template below:\"\n            \"[THINK]Your thoughts or/and draft, like working through an exercise on scratch paper. \"\n            \"Be as casual and as long as you want until you are confident to generate the response \"\n            \"to the user.[/THINK]Here, provide a self-contained response.\"\n        ),\n        role_templates={\n            \"user\": f\"[INST]{MessagePlaceholders.USER.value}[/INST]\",\n            \"assistant\": f\"{MessagePlaceholders.ASSISTANT.value}</s>\",\n            \"tool\": f\"[TOOL_RESULTS]{MessagePlaceholders.TOOL.value}[/TOOL_RESULTS]\",\n        },\n        roles={\"user\": \"\", \"assistant\": \"\", \"tool\": \"\"},\n        seps=[\"\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/mistral.py",
    "content": "\"\"\"Mistral default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Mistral default\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"mistral_default\",\n        system_template=f\"[INST] {MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"Always assist with care, respect, and truth. Respond with utmost \"\n        \"utility yet securely. Avoid harmful, unethical, prejudiced, or negative content. \"\n        \"Ensure replies promote fairness and positivity.\",\n        roles={\"user\": \"[INST]\", \"assistant\": \"[/INST]\", \"tool\": \"[INST]\"},\n        seps=[\" \"],\n        role_content_sep=\" \",\n        role_empty_sep=\"\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n        add_role_after_system_message=False,\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/nemotron.py",
    "content": "\"\"\"nemotron default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Nemotron template\n# https://huggingface.co/nvidia/Nemotron-Mini-4B-Instruct/blob/6a417790c444fd65a3da6a5c8821de6afc9654a6/tokenizer_config.json#L8030\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"nemotron\",\n        system_template=(f\"<extra_id_0>System\\n{MessagePlaceholders.SYSTEM.value}\\n\\n\"),\n        system_message=\"\",\n        roles={\n            \"user\": \"<extra_id_1>User\",\n            \"assistant\": \"<extra_id_1>Assistant\",\n            \"tool\": \"<extra_id_1>Tool\",\n        },\n        seps=[\"\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[3],\n        system_prefix_token_ids=[2],\n        add_role_after_system_message=True,\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/oasst.py",
    "content": "\"\"\"Oasst default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Oasst\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"oasst\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<|prompter|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"<|endoftext|>\"],\n        role_content_sep=\": \",\n        role_empty_sep=\": \",\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[2],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/olmo.py",
    "content": "\"\"\"OLMo default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Note that the eos_token id is 50279 in both the AllenAI and AMD versions,\n# so the token id is used here instead of the token text.\n# AllenAI version chat_template and eos_token:\n# https://huggingface.co/allenai/OLMo-7B-Instruct/blob/main/tokenizer_config.json\n# AMD version chat_template and eos_token:\n# https://huggingface.co/amd/AMD-OLMo-1B-SFT-DPO/blob/main/tokenizer_config.json\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"olmo\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        system_prefix_token_ids=[50279],\n        roles={\n            \"user\": \"<|user|>\",\n            \"assistant\": \"<|assistant|>\",\n        },\n        seps=[\"\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_token_ids=[50279],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/orion.py",
    "content": "\"\"\"Orion default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Orion\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"orion\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"Human: \", \"assistant\": \"Assistant: \"},\n        seps=[\"\\n\\n\", \"</s>\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"</s>\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/phi.py",
    "content": "\"\"\"Phi default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Phi-2\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"phi-2\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"Instruct\", \"assistant\": \"Output\"},\n        seps=[\"\\n\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[50256],\n    )\n)\n\n# Phi-3\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"phi-3\",\n        system_template=f\"<|system|>\\n{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"You are a helpful digital assistant. Please provide safe, \"\n        \"ethical and accurate information to the user.\",\n        roles={\"user\": \"<|user|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"<|end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        system_prefix_token_ids=[1],\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[2, 32000, 32001, 32007],\n    )\n)\n\n# Phi-3-vision\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"phi-3-vision\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<|user|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"<|end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        system_prefix_token_ids=[1],\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[2, 32000, 32001, 32007],\n    )\n)\n\n# Phi-4\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"phi-4\",\n        system_template=f\"<|system|>\\n{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"You are a 
helpful digital assistant. Please provide safe, \"\n        \"ethical and accurate information to the user.\",\n        roles={\"user\": \"<|user|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"<|end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        system_prefix_token_ids=[200022],  # <|system|>\n        stop_str=[\"<|endoftext|>\", \"<|end|>\"],\n        stop_token_ids=[199999, 200020],  # <|endoftext|>, <|end|>\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/qwen2.py",
    "content": "\"\"\"Qwen2 default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Same as chatml except system message, stop token, and stop string\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"qwen2\",\n        system_template=f\"<|im_start|>system\\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\\n\",\n        system_message=\"You are a helpful assistant.\",\n        roles={\"user\": \"<|im_start|>user\", \"assistant\": \"<|im_start|>assistant\"},\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|endoftext|>\", \"<|im_end|>\"],\n        stop_token_ids=[151643, 151645],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/redpajama.py",
    "content": "\"\"\"RedPajama default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# RedPajama Chat\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"redpajama_chat\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<human>\", \"assistant\": \"<bot>\"},\n        seps=[\"\\n\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"<human>\"],\n        stop_token_ids=[0],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/registry.py",
    "content": "\"\"\"The conversation template registry and presets in MLC LLM\"\"\"\n\nfrom typing import Dict, Optional\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\n\nclass ConvTemplateRegistry:\n    \"\"\"Global conversation template registry for preset templates.\"\"\"\n\n    _conv_templates: Dict[str, Conversation] = {}\n\n    @staticmethod\n    def register_conv_template(conv_template: Conversation, override: bool = False) -> None:\n        \"\"\"Register a new conversation template in the global registry.\n        Use `override=True` to override a previously registered\n        template with the same name.\n        \"\"\"\n        name = conv_template.name\n        if name is None:\n            raise ValueError(\"The template to register must have a non-None name.\")\n        if name in ConvTemplateRegistry._conv_templates and not override:\n            raise ValueError(\n                \"A template with the same name has already been registered: \"\n                f\"{ConvTemplateRegistry._conv_templates[name].model_dump_json(by_alias=True)}\"\n            )\n        ConvTemplateRegistry._conv_templates[name] = conv_template\n\n    @staticmethod\n    def get_conv_template(name: str) -> Optional[Conversation]:\n        \"\"\"Return the conversation template specified by the given name,\n        or None if the template is not registered.\n        \"\"\"\n        return ConvTemplateRegistry._conv_templates.get(name, None)\n\n\n# ChatML\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"chatml\",\n        system_template=f\"<|im_start|>system\\n{MessagePlaceholders.SYSTEM.value}<|im_end|>\\n\",\n        system_message=(\n            \"A conversation between a user and an LLM-based AI assistant. 
The \"\n            \"assistant gives helpful and honest answers.\"\n        ),\n        roles={\"user\": \"<|im_start|>user\", \"assistant\": \"<|im_start|>assistant\"},\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|im_end|>\"],\n        stop_token_ids=[2],\n    )\n)\n\n# ChatML without a system prompt\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"chatml_nosystem\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<|im_start|>user\", \"assistant\": \"<|im_start|>assistant\"},\n        seps=[\"<|im_end|>\\n\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|im_end|>\"],\n        stop_token_ids=[2],\n    )\n)\n\n\n# Vanilla LM\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"LM\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"\", \"assistant\": \"\"},\n        seps=[\"\"],\n        role_content_sep=\"\",\n        role_empty_sep=\"\",\n        stop_str=[],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/rwkv.py",
    "content": "\"\"\"RWKV default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# RWKV World\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"rwkv_world\",\n        system_template=f\"User: hi\\n\\nAssistant: {MessagePlaceholders.SYSTEM.value}\",\n        system_message=(\n            \"Hi. I am your assistant and I will provide expert full response \"\n            \"in full details. Please feel free to ask any question and I will \"\n            \"always answer it.\"\n        ),\n        roles={\"user\": \"User\", \"assistant\": \"Assistant\"},\n        seps=[\"\\n\\n\"],\n        role_content_sep=\": \",\n        role_empty_sep=\": \",\n        stop_str=[\"\\n\\n\"],\n        stop_token_ids=[0],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/stablelm.py",
    "content": "\"\"\"StableLM default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# StableLM Tuned Alpha\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"stablelm\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=(\n            \"<|SYSTEM|># StableLM Tuned (Alpha version)\\n\"\n            \"- StableLM is a helpful and harmless open-source AI language model developed by \"\n            \"StabilityAI.\\n\"\n            \"- StableLM is excited to be able to help the user, but will refuse to do \"\n            \"anything that could be considered harmful to the user.\\n\"\n            \"- StableLM is more than just an information source, StableLM is also able to \"\n            \"write poetry, short stories, and make jokes.\\n\"\n            \"- StableLM will refuse to participate in anything that could harm a human.\"\n        ),\n        roles={\"user\": \"<|USER|>\", \"assistant\": \"<|ASSISTANT|>\"},\n        seps=[\"\"],\n        role_content_sep=\": \",\n        role_empty_sep=\": \",\n        stop_str=[\"\"],\n        stop_token_ids=[50278, 50279, 50277, 1, 0],\n    )\n)\n\n# StableLM 3B\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"stablelm-3b\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"<|user|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"<|endoftext|>\", \"<|endoftext|>\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[0],\n    )\n)\n\n# StableLM-2\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"stablelm-2\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": 
\"<|user|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"<|endoftext|>\", \"<|endoftext|>\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"<|endoftext|>\"],\n        stop_token_ids=[100257],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/tinyllama.py",
    "content": "\"\"\"Tiny Llama default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# TinyLlama v1.0\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"tinyllama_v1_0\",\n        system_template=f\"<|system|>\\n{MessagePlaceholders.SYSTEM.value}</s>\",\n        system_message=\"You are a helpful chatbot.\",\n        roles={\"user\": \"<|user|>\", \"assistant\": \"<|assistant|>\"},\n        seps=[\"</s>\"],\n        role_content_sep=\"\\n\",\n        role_empty_sep=\"\\n\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/conversation_template/wizardlm.py",
    "content": "\"\"\"WizardLM and Coder default templates\"\"\"\n\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\nfrom .registry import ConvTemplateRegistry\n\n# Wizard LM 7B\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"wizardlm_7b\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=\"\",\n        roles={\"user\": \"User\", \"assistant\": \"Response\"},\n        seps=[\"###\"],\n        role_content_sep=\": \",\n        role_empty_sep=\":\",\n        stop_str=[\"###\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n\n# WizardCoder or WizardMath\nConvTemplateRegistry.register_conv_template(\n    Conversation(\n        name=\"wizard_coder_or_math\",\n        system_template=f\"{MessagePlaceholders.SYSTEM.value}\",\n        system_message=(\n            \"Below is an instruction that describes a task. Write a response that appropriately \"\n            \"completes the request.\"\n        ),\n        roles={\"user\": \"Instruction\", \"assistant\": \"Response\"},\n        seps=[\"\\n\\n### \", \"\\n\\n### \"],\n        role_content_sep=\":\\n\",\n        role_empty_sep=\":\\n\",\n        stop_str=[\"</s>\"],\n        stop_token_ids=[2],\n        system_prefix_token_ids=[1],\n    )\n)\n"
  },
  {
    "path": "python/mlc_llm/interface/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/interface/calibrate.py",
    "content": "\"\"\"Python entrypoint for calibration.\"\"\"\n\nimport asyncio\nimport json\nimport random\nfrom typing import List, Mapping, Optional, Tuple\n\nimport numpy as np\nimport tqdm.asyncio\nimport tvm\nfrom tvm.contrib import tvmjs\n\nfrom mlc_llm.serve.engine import AsyncMLCEngine, EngineConfig\nfrom mlc_llm.tokenizers import Tokenizer\n\n\nclass CalibrationObserver:\n    \"\"\"A singleton class to observe the calibration parameters.\"\"\"\n\n    instance: \"CalibrationObserver\" = None\n\n    params: Mapping[str, tvm.runtime.Tensor] = {}\n\n    @staticmethod\n    def get():\n        \"\"\"Get the singleton instance of the class.\"\"\"\n        if CalibrationObserver.instance is None:\n            CalibrationObserver.instance = CalibrationObserver()\n        return CalibrationObserver.instance\n\n    @tvm.register_global_func(\"mlc_llm.calibration_observer\")\n    @staticmethod\n    def callback(\n        name: str,\n        mode: str,\n        value: \"tvm.runtime.Tensor\",\n        out_value: \"tvm.runtime.Tensor\",\n    ):\n        \"\"\"The callback function to update the saved calibration parameters.\"\"\"\n        instance = CalibrationObserver.get()\n        if mode == \"max\":\n            reducer = np.maximum\n        else:\n            raise NotImplementedError(f\"Unsupported calibration mode: {mode}\")\n        if name in instance.params:\n            instance.params[name] = reducer(instance.params[name], value.numpy())\n        else:\n            instance.params[name] = value.numpy()\n        out_value.copyfrom(instance.params[name])\n\n    def save_params(self, output: str):\n        \"\"\"Save the calibration parameters to the given output directory.\"\"\"\n        tvmjs.dump_tensor_cache(\n            self.params,\n            output,\n            encode_format=\"f32-to-bf16\",\n            meta_data=None,\n            show_progress=False,\n            update_if_exists=True,\n        )\n\n\ndef sample_requests(\n    
dataset_path: str,\n    num_requests: int,\n    tokenizer: Tokenizer,\n) -> List[Tuple[str, int, int]]:\n    \"\"\"Sample the requests from the given dataset.\"\"\"\n    # pylint: disable=too-many-locals\n    # Load the dataset.\n    with open(dataset_path, encoding=\"utf-8\") as f:\n        dataset = json.load(f)\n\n    # Filter out the conversations with less than 2 turns.\n    dataset = [data for data in dataset if len(data[\"conversations\"]) >= 2]\n    # Only keep the first two turns of each conversation.\n    dataset = [\n        (data[\"conversations\"][0][\"value\"], data[\"conversations\"][1][\"value\"]) for data in dataset\n    ]\n    prompts = [prompt for prompt, _ in dataset]\n    prompt_token_ids = tokenizer.encode_batch(prompts)\n    completions = [completion for _, completion in dataset]\n    completion_token_ids = tokenizer.encode_batch(completions)\n    tokenized_dataset: List[Tuple[str, List[int], int]] = []\n    for i in range(len(dataset)):\n        output_len = len(completion_token_ids[i])\n        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))\n\n    # Filter out too long sequences.\n    filtered_dataset: List[Tuple[str, int, int]] = []\n    for prompt, token_ids, output_len in tokenized_dataset:\n        prompt_len = len(token_ids)\n        if prompt_len < 4 or output_len < 4:\n            # Prune too short sequences.\n            continue\n        if prompt_len > 1024 or prompt_len + output_len > 2048:\n            # Prune too long sequences.\n            continue\n        filtered_dataset.append((prompt, prompt_len, output_len))\n\n    # Sample the requests.\n    sampled_requests = random.sample(filtered_dataset, num_requests)\n    return sampled_requests\n\n\nasync def send_calibration_requests(\n    async_engine: AsyncMLCEngine,\n    sampled_requests: List[Tuple[str, int, int]],\n    max_concurrent_requests: int,\n) -> None:\n    \"\"\"Send the calibration requests to the engine.\"\"\"\n    tasks = []\n\n    
semaphore = asyncio.Semaphore(max_concurrent_requests)\n\n    async def generate_task(request_idx):\n        async with semaphore:\n            prompt, _, output_len = sampled_requests[request_idx]\n            await async_engine.chat.completions.create(\n                messages=[{\"role\": \"user\", \"content\": prompt}],\n                max_tokens=output_len,\n                request_id=str(request_idx),\n            )\n\n    for i in range(len(sampled_requests)):\n        task = asyncio.create_task(generate_task(i))\n        tasks.append(task)\n    await tqdm.asyncio.tqdm.gather(*tasks)\n\n\ndef calibrate(\n    model: str,\n    device: str,\n    model_lib: Optional[str],\n    dataset: str,\n    output: str,\n    num_calibration_samples: int,\n    *,\n    seed: int,\n    max_num_sequence: Optional[int] = None,\n    max_total_sequence_length: Optional[int] = None,\n    prefill_chunk_size: Optional[int] = None,\n    max_history_size: Optional[int] = None,\n    gpu_memory_utilization: Optional[float] = None,\n) -> None:\n    \"\"\"Calibrate the quantized model using the given dataset.\"\"\"\n    # pylint: disable=too-many-arguments, too-many-locals\n    random.seed(seed)\n    async_engine = AsyncMLCEngine(\n        model=model,\n        device=device,\n        model_lib=model_lib,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_num_sequence=max_num_sequence,\n            max_total_sequence_length=max_total_sequence_length,\n            prefill_chunk_size=prefill_chunk_size,\n            max_history_size=max_history_size,\n            gpu_memory_utilization=gpu_memory_utilization,\n        ),\n    )\n    sampled_requests = sample_requests(dataset, num_calibration_samples, async_engine.tokenizer)\n    asyncio.run(\n        send_calibration_requests(\n            async_engine,\n            sampled_requests,\n            max_concurrent_requests=max_num_sequence or 32,\n        )\n    )\n    async_engine.terminate()\n\n    calibrator = 
CalibrationObserver.get()\n    calibrator.save_params(output)\n"
  },
  {
    "path": "python/mlc_llm/interface/chat.py",
    "content": "\"\"\"Python entrypoint of chat.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, List, Optional, Union\n\nfrom prompt_toolkit import prompt as get_prompt  # pylint: disable=import-error\nfrom prompt_toolkit.key_binding import KeyBindings  # pylint: disable=import-error\n\nfrom mlc_llm.json_ffi import JSONFFIEngine\nfrom mlc_llm.protocol import openai_api_protocol\nfrom mlc_llm.serve.config import EngineConfig\nfrom mlc_llm.serve.engine import MLCEngine\nfrom mlc_llm.serve.engine_base import _query_engine_metrics\nfrom mlc_llm.support import argparse\nfrom mlc_llm.support.config import ConfigOverrideBase\n\n\ndef _print_help_str():\n    help_str = \"\"\"You can use the following special commands:\n  /help               print the special commands\n  /exit               quit the cli\n  /stats              print out stats of last request (token/sec)\n  /metrics            print out full engine metrics\n  /reset              restart a fresh chat\n  /set [overrides]    override settings in the generation config. 
For example,\n                      `/set temperature=0.5;top_p=0.8;seed=23;max_tokens=100;stop=str1,str2`\n                      Note: Separate stop words in the `stop` option with commas (,).\n  Multi-line input: Use escape+enter to start a new line.\n\"\"\"\n    print(help_str)\n\n\ndef _set_up_key_bindings():\n    kb = KeyBindings()\n\n    @kb.add(\"escape\", \"enter\")\n    def _(event):\n        event.current_buffer.insert_text(\"\\n\")\n\n    @kb.add(\"enter\")\n    def _(event):\n        event.current_buffer.validate_and_handle()\n\n    return kb\n\n\n@dataclasses.dataclass\nclass ChatCompletionOverride(ConfigOverrideBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Flags for overriding chat completions.\"\"\"\n\n    temperature: Optional[float] = None\n    top_p: Optional[float] = None\n    frequency_penalty: Optional[float] = None\n    presence_penalty: Optional[float] = None\n    max_tokens: Optional[int] = None\n    seed: Optional[int] = None\n    stop: Optional[Union[str, List[str]]] = None\n\n    @staticmethod\n    def from_str(source: str) -> \"ChatCompletionOverride\":\n        \"\"\"Parse chat completion override values from a string.\"\"\"\n        parser = argparse.ArgumentParser(description=\"chat completion override values\")\n        parser.add_argument(\"--temperature\", type=float, default=None)\n        parser.add_argument(\"--top_p\", type=float, default=None)\n        parser.add_argument(\"--frequency_penalty\", type=float, default=None)\n        parser.add_argument(\"--presence_penalty\", type=float, default=None)\n        parser.add_argument(\"--max_tokens\", type=int, default=None)\n        parser.add_argument(\"--seed\", type=int, default=None)\n        parser.add_argument(\"--stop\", type=str, default=None)\n        results = parser.parse_args([f\"--{i}\" for i in source.split(\";\") if i])\n        return ChatCompletionOverride(\n            temperature=results.temperature,\n            top_p=results.top_p,\n            
frequency_penalty=results.frequency_penalty,\n            presence_penalty=results.presence_penalty,\n            max_tokens=results.max_tokens,\n            seed=results.seed,\n            stop=results.stop.split(\",\") if results.stop is not None else None,\n        )\n\n\n@dataclasses.dataclass\nclass ModelConfigOverride(ConfigOverrideBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Flags for overriding model config.\"\"\"\n\n    context_window_size: Optional[int] = None\n    sliding_window_size: Optional[int] = None\n    prefill_chunk_size: Optional[int] = None\n    attention_sink_size: Optional[int] = None\n    tensor_parallel_shards: Optional[int] = None\n    pipeline_parallel_stages: Optional[int] = None\n    opt: Optional[str] = None\n\n    @staticmethod\n    def from_str(source: str) -> \"ModelConfigOverride\":\n        \"\"\"Parse model config override values from a string.\"\"\"\n        parser = argparse.ArgumentParser(description=\"model config override values\")\n        parser.add_argument(\"--tensor_parallel_shards\", type=int, default=None)\n        parser.add_argument(\"--pipeline_parallel_stages\", type=int, default=None)\n        parser.add_argument(\"--opt\", type=str, default=None)\n        parser.add_argument(\"--context_window_size\", type=int, default=None)\n        parser.add_argument(\"--sliding_window_size\", type=int, default=None)\n        parser.add_argument(\"--prefill_chunk_size\", type=int, default=None)\n        parser.add_argument(\"--attention_sink_size\", type=int, default=None)\n\n        results = parser.parse_args([f\"--{i}\" for i in source.split(\";\") if i])\n        return ModelConfigOverride(\n            tensor_parallel_shards=results.tensor_parallel_shards,\n            pipeline_parallel_stages=results.pipeline_parallel_stages,\n            opt=results.opt,\n            context_window_size=results.context_window_size,\n            sliding_window_size=results.sliding_window_size,\n            
prefill_chunk_size=results.prefill_chunk_size,\n            attention_sink_size=results.attention_sink_size,\n        )\n\n\nclass ChatState:\n    \"\"\"Simple helper class to manage chat state.\n\n    Chat state wraps around an engine instance\n    and exposes the minimum set of tools to perform\n    interactive chat. It provides support for mlc_llm chat.\n    It can also be used for interactive debugging\n    with different engine instances.\n\n    Examples\n    --------\n    .. code:: python\n\n        from openai import OpenAI\n        from mlc_llm import MLCEngine\n        from mlc_llm.serve import PopenServer\n        from mlc_llm.interface.chat import ChatState\n\n        def chat_with_engine(model):\n            # hookup with MLCEngine\n            ChatState(MLCEngine(model)).chat()\n\n        def chat_with_server(model):\n            # hookup with AsyncMLCEngine backed api server\n            with PopenServer(model) as server:\n                ChatState(\n                    OpenAI(base_url=server.openai_v1_base_url, api_key=\"None\")\n                ).chat()\n    \"\"\"\n\n    history: List[Dict[str, Any]]\n    history_window_begin: int\n    # kwargs passed to completions\n    overrides: ChatCompletionOverride\n    # Underlying engine\n    engine: Union[JSONFFIEngine, MLCEngine]\n    last_finished_request_usage: Optional[openai_api_protocol.CompletionUsage]\n\n    def __init__(self, engine: Union[JSONFFIEngine, MLCEngine]):\n        self.engine = engine\n        self.history = []\n        self.history_window_begin = 0\n        self.overrides = ChatCompletionOverride()\n        # model is mainly used for compatibility reasons\n        self.model = \"chat_model\"\n        self.last_finished_request_usage = None\n\n    def slide_history(self):\n        \"\"\"Slide history to fit into context window\"\"\"\n        history_window_size = len(self.history) - self.history_window_begin\n        assert history_window_size % 2 == 0\n        self.history_window_begin += 
((history_window_size + 3) // 4) * 2\n\n    def process_system_prompts(self):\n        \"\"\"Process system prompts\"\"\"\n        # TODO(mlc-team): possibly leverage debug option\n        # pass a simple prompt to warm up\n        for _ in self.engine.chat.completions.create(\n            messages=[{\"role\": \"user\", \"content\": \"\"}],\n            max_tokens=1,\n            model=self.model,\n            stream=True,\n        ):\n            pass\n\n    def generate(self, prompt: str):\n        \"\"\"Run one generation with the prompt.\n\n        Parameters\n        ----------\n        prompt: str\n            The input prompt\n        \"\"\"\n        self.history.append({\"role\": \"user\", \"content\": prompt})\n        output_text = \"\"\n        finish_reason_length = False\n        messages = self.history[self.history_window_begin :]\n\n        for response in self.engine.chat.completions.create(\n            messages=messages,\n            model=self.model,\n            stream=True,\n            stream_options={\"include_usage\": True},\n            **dataclasses.asdict(self.overrides),\n        ):\n            if response.usage is not None:\n                self.last_finished_request_usage = response.usage\n                continue\n            for choice in response.choices:\n                assert choice.delta.role == \"assistant\"\n                if isinstance(choice.delta.content, str):\n                    output_text += choice.delta.content\n                    print(choice.delta.content, end=\"\", flush=True)\n                if choice.finish_reason == \"length\":\n                    finish_reason_length = True\n        if finish_reason_length:\n            print(\" [output truncated due to context length limit...]\")\n        # print additional \\n when generation ends\n        print()\n        # record the history\n        self.history.append({\"role\": \"assistant\", \"content\": output_text})\n        if finish_reason_length:\n            
self.slide_history()\n\n    def stats(self):\n        \"\"\"Print statistics of the prefill and decode speed.\"\"\"\n\n        def get_stats_text():\n            \"\"\"Get text\"\"\"\n            if self.last_finished_request_usage is None:\n                return \"N/A\"\n            last_finished_request = self.last_finished_request_usage.extra\n            if last_finished_request is None:\n                return \"N/A\"\n            prefill_speed = last_finished_request.get(\"prefill_tokens_per_s\", None)\n            decode_speed = last_finished_request.get(\"decode_tokens_per_s\", None)\n            prefill_speed = f\"{prefill_speed:.1f}\" if prefill_speed is not None else \"N/A\"\n            decode_speed = f\"{decode_speed:.1f}\" if decode_speed is not None else \"N/A\"\n            return f\"prefill: {prefill_speed} tok/s, decode: {decode_speed} tok/s\"\n\n        print(get_stats_text(), flush=True)\n\n    def metrics(self):\n        \"\"\"Print metrics as prometheus text\"\"\"\n        print(_query_engine_metrics(self.engine).prometheus_text(), flush=True)\n\n    def reset(self):\n        \"\"\"Reset the chat history\"\"\"\n        self.history = []\n        self.history_window_begin = 0\n\n    def chat(self):\n        \"\"\"Start an interactive chat session.\"\"\"\n        _print_help_str()\n\n        self.process_system_prompts()  # pylint: disable=protected-access\n        # Multi-line input support: set escape+enter as start a new line\n        kb = _set_up_key_bindings()\n\n        while True:\n            try:\n                prompt = get_prompt(\n                    \">>> \",  # pylint: disable=protected-access\n                    key_bindings=kb,\n                    multiline=True,\n                )\n            except (KeyboardInterrupt, EOFError):\n                break\n            if prompt[:4] == \"/set\":\n                overrides = ChatCompletionOverride.from_str(prompt.split()[1])\n                for key, value in 
dataclasses.asdict(overrides).items():\n                    if value is not None:\n                        setattr(self.overrides, key, value)\n            elif prompt[:6] == \"/stats\":\n                self.stats()\n            elif prompt[:8] == \"/metrics\":\n                self.metrics()\n            elif prompt[:6] == \"/reset\":\n                self.reset()\n            elif prompt[:5] == \"/exit\":\n                break\n            elif prompt[:5] == \"/help\":\n                _print_help_str()\n            else:\n                self.generate(prompt)\n\n\ndef chat(\n    model: str,\n    device: str,\n    model_lib: Optional[str],\n    overrides: ModelConfigOverride,\n):\n    \"\"\"Chat cli entry\"\"\"\n    # By default we use JSONFFIEngine\n    engine = JSONFFIEngine(\n        model,\n        device,\n        model_lib=model_lib,\n        mode=\"interactive\",\n        engine_config=EngineConfig(\n            max_single_sequence_length=overrides.context_window_size,\n            prefill_chunk_size=overrides.prefill_chunk_size,\n            sliding_window_size=overrides.sliding_window_size,\n            attention_sink_size=overrides.attention_sink_size,\n            tensor_parallel_shards=overrides.tensor_parallel_shards,\n            pipeline_parallel_stages=overrides.pipeline_parallel_stages,\n            opt=overrides.opt,\n        ),\n    )\n    try:\n        ChatState(engine).chat()\n    finally:\n        engine.terminate()\n"
  },
  {
    "path": "python/mlc_llm/interface/compile.py",
    "content": "\"\"\"Python entrypoint of compilation.\"\"\"\n\nimport dataclasses\nfrom io import StringIO\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, List, Optional, Tuple\n\nfrom tvm import IRModule, relax, tir\nfrom tvm.ir.transform import Pass, PassContext\nfrom tvm.relax.frontend import nn\nfrom tvm.target import Target\n\nfrom mlc_llm import compiler_pass as _\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.cli.model_metadata import _report_memory_usage\nfrom mlc_llm.model import Model\nfrom mlc_llm.quantization import Quantization\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nfrom .compiler_flags import ModelConfigOverride, OptimizationFlags\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass CompileArgs:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Arguments to MLC LLM's compiler.\"\"\"\n\n    config: Path\n    quantization: Quantization\n    model: Model\n    target: Target\n    opt: OptimizationFlags\n    build_func: Callable[[IRModule, \"CompileArgs\", Pass], None]\n    system_lib_prefix: str\n    output: Path\n    overrides: ModelConfigOverride\n    debug_dump: Optional[Path]\n\n    def __post_init__(self) -> None:\n        self.opt.update(self.target, self.quantization)\n\n    def display(self) -> None:\n        \"\"\"Display the arguments to stdout.\"\"\"\n        out = StringIO()\n        print(f\"{bold('Compiling with arguments:')}\", file=out)\n        print(f\"  {bold('--config'):<25} {self.config}\", file=out)\n        print(f\"  {bold('--quantization'):<25} {self.quantization}\", file=out)\n        print(f\"  {bold('--model-type'):<25} {self.model.name}\", file=out)\n        print(f\"  {bold('--target'):<25} {self.target.export()}\", file=out)\n        print(f\"  {bold('--opt'):<25} {self.opt}\", file=out)\n        print(f'  {bold(\"--system-lib-prefix\"):<25} \"{self.system_lib_prefix}\"', file=out)\n    
    print(f\"  {bold('--output'):<25} {self.output}\", file=out)\n        print(f\"  {bold('--overrides'):<25} {self.overrides}\", file=out)\n        # As it's debug only, no need to display\n        # print(f\"  {bold('--debug-dump'):<25} {self.debug_dump}\", file=out)\n        print(out.getvalue().rstrip())\n\n\ndef _apply_preproc_to_params_and_check_pipeline(\n    named_params: List[Tuple[str, nn.Parameter]],\n    model_config,\n) -> Dict[str, tir.PrimFunc]:\n    extra_tirs: Dict[str, tir.PrimFunc] = {}\n    for name, param in named_params:\n        preprocs = param.attrs.get(\"preprocs\", [])\n        shard_strategy = param.attrs.get(\"shard_strategy\", None)\n        if shard_strategy is not None and model_config.tensor_parallel_shards > 1:\n            preprocs.append(\n                shard_strategy.gen_shard_info(\n                    shards=model_config.tensor_parallel_shards,\n                    weight=param,\n                )\n            )\n            if shard_strategy.name not in extra_tirs:\n                extra_tirs[shard_strategy.name] = shard_strategy.gen_tir(\n                    shards=model_config.tensor_parallel_shards,\n                    weight=param,\n                )\n        param.attrs[\"preprocs\"] = preprocs\n\n        pipeline_parallel_stages = getattr(model_config, \"pipeline_parallel_stages\", 1)\n        if pipeline_parallel_stages != 1:\n            assert \"pipeline_stages\" in param.attrs, (\n                f'The pipeline stage is undefined for parameter \"{name}\" when the number '\n                f\"of pipeline parallel stages is {pipeline_parallel_stages}\"\n            )\n        param.attrs[\"pipeline_stages\"] = (\n            [0]\n            if \"pipeline_stages\" not in param.attrs\n            else list(set(param.attrs[\"pipeline_stages\"]))\n        )\n    return extra_tirs\n\n\ndef _infer_kv_state_kind(model_type) -> str:\n    if \"rwkv\" in model_type:\n        return \"rnn_state\"\n    if \"medusa\" in 
model_type:\n        return \"none\"\n    return \"kv_cache\"\n\n\ndef _compile(args: CompileArgs, model_config: ConfigBase):\n    def _get_variable_bounds(model_config) -> Dict[str, int]:\n        if hasattr(model_config, \"sliding_window_size\"):\n            return {\n                \"rolling_cache_len\": model_config.sliding_window_size,\n                \"kv_seq_len\": model_config.sliding_window_size + model_config.prefill_chunk_size,\n                \"seq_len\": model_config.prefill_chunk_size,\n                \"batch_size\": getattr(model_config, \"max_batch_size\", 1),\n            }\n        return {\n            \"total_seq_len\": model_config.context_window_size,\n            \"seq_len\": model_config.prefill_chunk_size,\n            \"batch_size\": getattr(model_config, \"max_batch_size\", 1),\n        }\n\n    def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:\n        return {\n            \"name\": name,\n            # Record dynamic dimensions by their symbolic name (e.g. vocab_size)\n            \"shape\": [s if isinstance(s, int) else s.name for s in param.shape],\n            \"dtype\": param.dtype,\n            \"preprocs\": param.attrs[\"preprocs\"],\n            \"pipeline_stages\": param.attrs.get(\"pipeline_stages\", [0]),\n        }\n\n    logger.info(\"TOP LEVEL MODEL CONFIG BEFORE OVERRIDES: %s\", str(model_config))\n    _kwargs = getattr(model_config, \"kwargs\", {})\n    model_config = args.overrides.apply(model_config)\n    with args.target:\n        op_ext.enable(\n            target=args.target,\n            flashinfer=args.opt.flashinfer,\n            faster_transformer=args.opt.faster_transformer,\n            cutlass=args.opt.cutlass,\n        )\n        # Step 1. 
Create the quantized model\n        logger.info(\"Creating model from: %s\", model_config)\n        if (\n            args.quantization.kind == \"ft-quant\"\n            and hasattr(model_config, \"tensor_parallel_shards\")\n            and model_config.tensor_parallel_shards > 1  # type: ignore\n        ):\n            raise NotImplementedError\n        if (\n            hasattr(args.quantization, \"linear_weight_layout\")\n            and args.quantization.linear_weight_layout == \"KN\"\n            and hasattr(model_config, \"tensor_parallel_shards\")\n            and model_config.tensor_parallel_shards > 1  # type: ignore\n        ):\n            raise NotImplementedError(\n                \"KN layout (q3f16_0 and q4f16_0) is not supported for tensor parallelism\"\n            )\n        model, _ = args.model.quantize[args.quantization.kind](model_config, args.quantization)\n        # Step 2. Exporting the model to TVM\n        logger.info(\"Exporting the model to TVM compiler\")\n        mod, named_params, ext_mods = model.export_tvm(\n            spec=model.get_default_spec(),  # type: ignore\n            allow_extern=True,\n        )\n        # Step 3. 
Running relax compilation pipeline\n        logger.info(\"Running optimizations using TVM\")\n        additional_tirs = _apply_preproc_to_params_and_check_pipeline(named_params, model_config)\n        variable_bounds = _get_variable_bounds(model_config)\n        cuda_graph_symbolic_capture_hints = {\n            \"batch_decode\": [\"batch_size\"],\n            \"batch_decode_to_last_hidden_states\": [\"batch_size\"],\n            \"batch_verify\": [\"batch_size\", \"seq_len\"],\n            \"batch_verify_to_last_hidden_states\": [\"batch_size\", \"seq_len\"],\n        }\n        avs = _kwargs.get(\"active_vocab_size\", None)\n        if avs is not None and avs <= 0:\n            avs = None\n        metadata = {\n            \"model_type\": args.model.name,\n            \"quantization\": args.quantization.name,\n            \"context_window_size\": getattr(model_config, \"context_window_size\", -1),\n            \"sliding_window_size\": getattr(model_config, \"sliding_window_size\", -1),\n            \"attention_sink_size\": getattr(model_config, \"attention_sink_size\", -1),\n            \"prefill_chunk_size\": model_config.prefill_chunk_size,  # type: ignore\n            \"tensor_parallel_shards\": model_config.tensor_parallel_shards,  # type: ignore\n            \"pipeline_parallel_stages\": getattr(model_config, \"pipeline_parallel_stages\", 1),\n            \"disaggregation\": getattr(model_config, \"disaggregation\", False),\n            \"kv_state_kind\": _infer_kv_state_kind(args.model.name),\n            \"max_batch_size\": getattr(model_config, \"max_batch_size\", 1),\n            \"active_vocab_size\": avs,\n            \"model_task\": args.model.model_task,\n        }\n        if args.model.embedding_metadata:\n            metadata[\"embedding_metadata\"] = dataclasses.asdict(args.model.embedding_metadata)\n        logger.info(\"Registering metadata: %s\", metadata)\n        metadata[\"params\"] = [_get_param_metadata(name, param) for name, param in 
named_params]\n        pass_config = {\"relax.backend.use_cuda_graph\": args.opt.cudagraph}\n        # TODO: Remove this workaround when the TVM CSE regression is fixed.\n        # Temporary workaround for TVM CSE regression that can produce\n        # dangling `cse_v*` vars during host codegen.\n        pass_config[\"tir.disable_cse_tir\"] = True\n\n        with PassContext(config=pass_config):\n            args.build_func(\n                mod,\n                args,\n                pipeline=relax.get_pipeline(  # type: ignore\n                    \"mlc_llm\",\n                    target=args.target,\n                    flashinfer=args.opt.flashinfer,\n                    cublas_gemm=args.opt.cublas_gemm,\n                    faster_transformer=args.opt.faster_transformer,\n                    allreduce_strategy=args.opt.ipc_allreduce_strategy,\n                    variable_bounds=variable_bounds,\n                    cuda_graph_symbolic_capture_hints=cuda_graph_symbolic_capture_hints,\n                    additional_tirs=additional_tirs,\n                    ext_mods=ext_mods,\n                    metadata=metadata,\n                    debug_dump=args.debug_dump,\n                ),\n            )\n        _report_memory_usage(metadata=metadata, config=model_config)\n    logger.info(\"Generated: %s\", bold(str(args.output)))\n\n\ndef compile(  # pylint: disable=too-many-arguments,redefined-builtin\n    config: Dict[str, Any],\n    quantization: Quantization,\n    model_type: Model,\n    target: Target,\n    opt: OptimizationFlags,\n    build_func: Callable[[IRModule, CompileArgs, Pass], None],\n    system_lib_prefix: str,\n    output: Path,\n    overrides: ModelConfigOverride,\n    debug_dump: Optional[Path] = None,\n):\n    \"\"\"Compile a model given its configuration and quantization format to a specific target.\"\"\"\n    avs = None\n    if \"active_vocab_size\" in config:\n        avs = config.pop(\"active_vocab_size\")\n        logger.info(\"Active 
vocab size from input config: %s\", str(avs))\n    if \"model_config\" in config:\n        model_config = config.pop(\"model_config\")\n        model_config.update(config)\n        model_config = model_type.config.from_dict(model_config)\n    else:\n        model_config = model_type.config.from_dict(config)\n    model_config.kwargs = {\"active_vocab_size\": avs} if avs is not None else {}\n    args = CompileArgs(\n        model_config,\n        quantization,\n        model_type,\n        target,\n        opt,\n        build_func,\n        system_lib_prefix,\n        output,\n        overrides,\n        debug_dump,\n    )\n    args.display()\n    _compile(args, model_config)\n"
  },
  {
    "path": "python/mlc_llm/interface/compiler_flags.py",
    "content": "\"\"\"Flags for overriding model config.\"\"\"\n\nimport dataclasses\nimport enum\nfrom io import StringIO\nfrom typing import Optional\n\nfrom mlc_llm.support import argparse, logging\nfrom mlc_llm.support.config import ConfigOverrideBase\n\nlogger = logging.getLogger(__name__)\n\n\nclass IPCAllReduceStrategyType(enum.IntEnum):\n    \"\"\"The all-reduce strategy.\"\"\"\n\n    NONE = 0\n    ONESHOT = 1\n    TWOSHOT = 2\n    AUTO = 3\n\n\n@dataclasses.dataclass\nclass OptimizationFlags:\n    \"\"\"Optimization flags\"\"\"\n\n    flashinfer: bool = False\n    cublas_gemm: bool = False\n    faster_transformer: bool = False\n    cudagraph: bool = False\n    cutlass: bool = False\n    ipc_allreduce_strategy: IPCAllReduceStrategyType = IPCAllReduceStrategyType.NONE\n\n    def __repr__(self) -> str:\n        out = StringIO()\n        print(f\"flashinfer={int(self.flashinfer)}\", file=out, end=\"\")\n        print(f\";cublas_gemm={int(self.cublas_gemm)}\", file=out, end=\"\")\n        print(f\";faster_transformer={int(self.faster_transformer)}\", file=out, end=\"\")\n        print(f\";cudagraph={int(self.cudagraph)}\", file=out, end=\"\")\n        print(f\";cutlass={int(self.cutlass)}\", file=out, end=\"\")\n        print(\n            f\";ipc_allreduce_strategy={self.ipc_allreduce_strategy.name}\",\n            file=out,\n            end=\"\",\n        )\n        return out.getvalue().rstrip()\n\n    @staticmethod\n    def from_str(source: str) -> \"OptimizationFlags\":\n        \"\"\"Parse optimization flags from a string.\"\"\"\n\n        if source in OPT_FLAG_PRESET:\n            return OPT_FLAG_PRESET[source]\n\n        def boolean(value: str) -> bool:\n            if value == \"0\":\n                return False\n            if value == \"1\":\n                return True\n            raise ValueError(f\"Invalid boolean value: {value}\")\n\n        parser = argparse.ArgumentParser(description=\"optimization flags\")\n        
parser.add_argument(\"--flashinfer\", type=boolean, default=True)\n        parser.add_argument(\"--cublas_gemm\", type=boolean, default=False)\n        parser.add_argument(\"--faster_transformer\", type=boolean, default=False)\n        parser.add_argument(\"--cudagraph\", type=boolean, default=False)\n        parser.add_argument(\"--cutlass\", type=boolean, default=False)\n        parser.add_argument(\n            \"--ipc_allreduce_strategy\",\n            type=str,\n            choices=[\"NONE\", \"ONESHOT\", \"TWOSHOT\", \"AUTO\"],\n            default=\"NONE\",\n        )\n        results = parser.parse_args([f\"--{i}\" for i in source.split(\";\") if i])\n        return OptimizationFlags(\n            flashinfer=results.flashinfer,\n            cublas_gemm=results.cublas_gemm,\n            faster_transformer=results.faster_transformer,\n            cudagraph=results.cudagraph,\n            cutlass=results.cutlass,\n            ipc_allreduce_strategy=IPCAllReduceStrategyType[results.ipc_allreduce_strategy],\n        )\n\n    def update(self, target, quantization) -> None:\n        \"\"\"Update optimization flags based on additional information.\"\"\"\n\n        def _flashinfer(target) -> bool:\n            from mlc_llm.support.auto_target import (  # pylint: disable=import-outside-toplevel\n                detect_cuda_arch_list,\n            )\n\n            if not self.flashinfer:\n                return False\n            if target.kind.name != \"cuda\":\n                return False\n            arch_list = detect_cuda_arch_list(target)\n            for arch in arch_list:\n                if arch < 80:\n                    logger.warning(\"flashinfer is not supported on CUDA arch < 80\")\n                    return False\n            return True\n\n        def _cublas_gemm(target, quantization) -> bool:\n            \"\"\"correct cublas_gemm flag\"\"\"\n            if not target.kind.name in [\"cuda\", \"rocm\"]:\n                return False\n            if 
not (\n                quantization.name in [\"q0f16\", \"q0bf16\", \"q0f32\"]\n                or \"e4m3\" in quantization.name\n                or \"e5m2\" in quantization.name\n            ):\n                return False\n            return self.cublas_gemm\n\n        def _faster_transformer(target) -> bool:\n            \"\"\"correct faster_transformer flag\"\"\"\n            if not target.kind.name == \"cuda\":\n                return False\n            return self.faster_transformer\n\n        def _cutlass(target) -> bool:\n            \"\"\"correct cutlass flag\"\"\"\n            if not target.kind.name == \"cuda\":\n                return False\n            return self.cutlass\n\n        def _cudagraph(target) -> bool:\n            \"\"\"correct cudagraph flag\"\"\"\n            if not target.kind.name == \"cuda\":\n                return False\n            return self.cudagraph\n\n        self.flashinfer = _flashinfer(target)\n        self.cublas_gemm = _cublas_gemm(target, quantization)\n        self.faster_transformer = _faster_transformer(target)\n        self.cutlass = _cutlass(target)\n        self.cudagraph = _cudagraph(target)\n\n\n@dataclasses.dataclass\nclass ModelConfigOverride(ConfigOverrideBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Flags for overriding model config.\"\"\"\n\n    context_window_size: Optional[int] = None\n    sliding_window_size: Optional[int] = None\n    prefill_chunk_size: Optional[int] = None\n    attention_sink_size: Optional[int] = None\n    max_batch_size: Optional[int] = None\n    tensor_parallel_shards: Optional[int] = None\n    pipeline_parallel_stages: Optional[int] = None\n    disaggregation: Optional[bool] = None\n\n    def __repr__(self) -> str:\n        out = StringIO()\n        print(f\"context_window_size={self.context_window_size}\", file=out, end=\"\")\n        print(f\";sliding_window_size={self.sliding_window_size}\", file=out, end=\"\")\n        
print(f\";prefill_chunk_size={self.prefill_chunk_size}\", file=out, end=\"\")\n        print(f\";attention_sink_size={self.attention_sink_size}\", file=out, end=\"\")\n        print(f\";max_batch_size={self.max_batch_size}\", file=out, end=\"\")\n        print(f\";tensor_parallel_shards={self.tensor_parallel_shards}\", file=out, end=\"\")\n        print(\n            f\";pipeline_parallel_stages={self.pipeline_parallel_stages}\",\n            file=out,\n            end=\"\",\n        )\n        print(f\";disaggregation={self.disaggregation}\", file=out, end=\"\")\n        return out.getvalue().rstrip()\n\n    @staticmethod\n    def from_str(source: str) -> \"ModelConfigOverride\":\n        \"\"\"Parse model config override values from a string.\"\"\"\n        parser = argparse.ArgumentParser(description=\"model config override values\")\n        parser.add_argument(\"--context_window_size\", type=int, default=None)\n        parser.add_argument(\"--sliding_window_size\", type=int, default=None)\n        parser.add_argument(\"--prefill_chunk_size\", type=int, default=None)\n        parser.add_argument(\"--attention_sink_size\", type=int, default=None)\n        parser.add_argument(\"--max_batch_size\", type=int, default=None)\n        parser.add_argument(\"--tensor_parallel_shards\", type=int, default=None)\n        parser.add_argument(\"--pipeline_parallel_stages\", type=int, default=None)\n        parser.add_argument(\n            \"--disaggregation\",\n            type=lambda x: str(x).lower() in [\"true\", \"1\", \"yes\", \"True\"],\n            default=None,\n        )\n        results = parser.parse_args([f\"--{i}\" for i in source.split(\";\") if i])\n        return ModelConfigOverride(\n            context_window_size=results.context_window_size,\n            sliding_window_size=results.sliding_window_size,\n            prefill_chunk_size=results.prefill_chunk_size,\n            attention_sink_size=results.attention_sink_size,\n            
max_batch_size=results.max_batch_size,\n            tensor_parallel_shards=results.tensor_parallel_shards,\n            pipeline_parallel_stages=results.pipeline_parallel_stages,\n            disaggregation=results.disaggregation,\n        )\n\n\nOPT_FLAG_PRESET = {\n    \"O0\": OptimizationFlags(\n        flashinfer=False,\n        cublas_gemm=False,\n        cudagraph=False,\n    ),\n    \"O1\": OptimizationFlags(\n        flashinfer=False,\n        cublas_gemm=True,\n        faster_transformer=True,\n        cudagraph=False,\n        cutlass=True,\n    ),\n    \"O2\": OptimizationFlags(\n        flashinfer=True,\n        cublas_gemm=True,\n        faster_transformer=False,\n        cudagraph=True,\n        cutlass=True,\n        ipc_allreduce_strategy=IPCAllReduceStrategyType.NONE,\n    ),\n    \"O3\": OptimizationFlags(\n        flashinfer=True,\n        cublas_gemm=True,\n        faster_transformer=True,\n        cudagraph=True,\n        cutlass=True,\n        ipc_allreduce_strategy=IPCAllReduceStrategyType.AUTO,\n    ),\n}\n"
  },
  {
    "path": "python/mlc_llm/interface/convert_weight.py",
    "content": "\"\"\"Python entrypoint of weight conversion.\"\"\"\n\nimport contextlib\nimport dataclasses\nimport math\nimport os\nimport tempfile\nfrom io import StringIO\nfrom pathlib import Path\nfrom typing import Any, Dict, Iterator, Optional, Tuple\n\nfrom tvm import tir\nfrom tvm.contrib import tvmjs\nfrom tvm.runtime import DataType, Device, Tensor\nfrom tvm.runtime import cpu as cpu_device\nfrom tvm.target import Target\n\nfrom mlc_llm.loader import LOADER\nfrom mlc_llm.model import Model\nfrom mlc_llm.quantization import Quantization\nfrom mlc_llm.support import logging, tqdm\nfrom mlc_llm.support.auto_weight import detect_weight\nfrom mlc_llm.support.preshard import apply_preshard\nfrom mlc_llm.support.style import bold, green\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass ConversionArgs:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Arguments to MLC LLM's weight conversion and quantization flow.\"\"\"\n\n    config: Path\n    quantization: Quantization\n    model: Model\n    device: Device\n    source: Path\n    source_format: str\n    output: Path\n    lora_adapter: Optional[Path] = None\n\n    def display(self) -> None:\n        \"\"\"Display the arguments to stdout.\"\"\"\n\n        def _device_to_str(device: Device) -> str:\n            return f\"{Device._DEVICE_TYPE_TO_NAME[device.dlpack_device_type()]}:{device.index}\"  # pylint: disable=protected-access, line-too-long\n\n        out = StringIO()\n        print(f\"{bold('Weight conversion with arguments:')}\", file=out)\n        print(f\"  {bold('--config'):<25} {self.config}\", file=out)\n        print(f\"  {bold('--quantization'):<25} {self.quantization}\", file=out)\n        print(f\"  {bold('--model-type'):<25} {self.model.name}\", file=out)\n        print(f\"  {bold('--device'):<25} {_device_to_str(self.device)}\", file=out)\n        print(f\"  {bold('--source'):<25} {self.source}\", file=out)\n        print(f\"  {bold('--source-format'):<25} 
{self.source_format}\", file=out)\n        print(f\"  {bold('--output'):<25} {self.output}\", file=out)\n        if self.lora_adapter is not None:\n            print(f\"  {bold('--lora-adapter'):<25} {self.lora_adapter}\", file=out)\n        print(out.getvalue().rstrip())\n\n\ndef _resolve_base_model_dir(source: Path) -> Path:\n    return source if source.is_dir() else source.parent\n\n\n@contextlib.contextmanager\ndef _merge_lora_adapter_with_base_model(base_source: Path, lora_adapter: Path) -> Iterator[Path]:\n    base_model_dir = _resolve_base_model_dir(base_source)\n    if not base_model_dir.exists():\n        raise ValueError(f\"Base model directory does not exist: {base_model_dir}\")\n    if not lora_adapter.exists() or not lora_adapter.is_dir():\n        raise ValueError(f\"LoRA adapter directory does not exist: {lora_adapter}\")\n\n    try:\n        # pylint: disable=import-outside-toplevel\n        from peft import PeftModel\n        from transformers import AutoModelForCausalLM\n\n        # pylint: enable=import-outside-toplevel\n    except ImportError as err:\n        raise ImportError(\n            \"`--lora-adapter` requires `peft` and `transformers` to be installed.\"\n        ) from err\n\n    with tempfile.TemporaryDirectory() as temp_dir:\n        merged_model_dir = Path(temp_dir) / \"merged_model\"\n        logger.info(\"Merging LoRA adapter %s into base model %s\", lora_adapter, base_model_dir)\n\n        base_model = AutoModelForCausalLM.from_pretrained(\n            str(base_model_dir),\n            torch_dtype=\"auto\",\n            trust_remote_code=False,\n            low_cpu_mem_usage=True,\n        )\n        merged_model = PeftModel.from_pretrained(\n            base_model, str(lora_adapter), is_trainable=False\n        ).merge_and_unload()\n        merged_model.save_pretrained(str(merged_model_dir), safe_serialization=True)\n        yield merged_model_dir\n\n\ndef _convert_args(args: ConversionArgs) -> None:  # pylint: 
disable=too-many-locals\n    pre_shards_num = os.getenv(\"MLC_INTERNAL_PRESHARD_NUM\")\n    # model config & quantization config\n    model_config = args.model.config.from_file(args.config)\n    if (\n        args.quantization.kind == \"ft-quant\"\n        and hasattr(model_config, \"tensor_parallel_shards\")\n        and model_config.tensor_parallel_shards > 1\n    ):\n        raise NotImplementedError\n    if pre_shards_num is not None:\n        model_config.tensor_parallel_shards = int(pre_shards_num)\n    model, quantize_map = args.model.quantize[args.quantization.kind](\n        model_config, args.quantization\n    )\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),  # type: ignore[attr-defined]\n        allow_extern=True,\n    )\n    named_params = dict(_named_params)\n\n    if pre_shards_num is not None:\n        named_params, preshard_funcs = apply_preshard(named_params, int(pre_shards_num), args)\n    else:\n        preshard_funcs = None\n\n    def _check_param(name: str, param: Tensor):\n        nonlocal named_params\n        if name not in named_params:\n            raise ValueError(f\"Parameter not found in model: {name}\")\n        if name in param_names:\n            raise ValueError(f\"Duplication: Parameter {name} already computed\")\n\n        # Check shape (possibly dynamic)\n        def _check_shape(actual: tuple, expect: tuple):  # expect can have tir.Var\n            if len(actual) != len(expect):\n                return False\n            for actual_i, expect_i in zip(actual, expect):\n                assert isinstance(expect_i, (int, tir.Var))\n                if isinstance(expect_i, int) and actual_i != expect_i:\n                    return False\n            return True\n\n        expect_shape = named_params[name].shape\n        actual_shape = param.shape\n        if not _check_shape(actual_shape, expect_shape):\n            raise ValueError(\n                f\"Parameter {name} has 
shape {param.shape}, but expected {expect_shape}\"\n            )\n        # Check dtype\n        actual_dtype = param.dtype\n        expect_dtype = named_params[name].dtype\n        if actual_dtype != expect_dtype:\n            raise ValueError(\n                f\"Parameter {name} has dtype {param.dtype}, but expected {expect_dtype}\"\n            )\n        del named_params[name]\n\n    # load and quantize\n    param_names = set()\n    total_bytes = 0.0\n    total_params: int\n\n    def _param_generator() -> Iterator[Tuple[str, Tensor]]:\n        nonlocal total_params, total_bytes\n        with Target.from_device(args.device), tqdm.redirect():\n            loader = LOADER[args.source_format](\n                path=args.source,\n                extern_param_map=args.model.source[args.source_format](\n                    model_config, args.quantization\n                ),\n                quantize_param_map=quantize_map,\n            )\n            for name, param in loader.load(device=args.device, preshard_funcs=preshard_funcs):\n                _check_param(name, param)\n                param_names.add(name)\n                param = param.copyto(cpu_device())\n                total_bytes += math.prod(param.shape) * DataType(param.dtype).itemsize\n                yield name, param\n        total_params = loader.stats.total_param_num\n\n    def _metadata_callback() -> Dict[str, Any]:\n        return {\n            \"ParamSize\": len(param_names),\n            \"ParamBytes\": total_bytes,\n            \"BitsPerParam\": total_bytes * 8.0 / total_params,\n        }\n\n    # dump to output directory\n    tvmjs.dump_tensor_cache(\n        _param_generator(),\n        str(args.output),\n        meta_data=_metadata_callback,\n        encode_format=\"f32-to-bf16\",\n        show_progress=False,\n    )\n    if named_params:\n        raise ValueError(f\"Parameter not found in source: {', '.join(named_params.keys())}\")\n    # Log necessary statistics\n    logger.info(\n     
   \"%s after quantization: %.3f GB\",\n        green(\"Parameter size\"),\n        total_bytes / (1024**3),\n    )\n    logger.info(f\"%s: {total_params:,}\", green(\"Total parameters\"))\n    logger.info(\n        \"%s: %.3f\",\n        green(\"Bits per parameter\"),\n        total_bytes * 8.0 / total_params,\n    )\n    logger.info(\"Saved to directory: %s\", bold(str(args.output)))\n\n\ndef convert_weight(  # pylint: disable=too-many-arguments\n    config: Path,\n    quantization: Quantization,\n    model: Model,\n    device: Device,\n    source: Path,\n    source_format: str,\n    output: Path,\n    lora_adapter: Optional[Path] = None,\n):\n    \"\"\"MLC LLM's weight conversion and quantization flow.\"\"\"\n    args = ConversionArgs(\n        config, quantization, model, device, source, source_format, output, lora_adapter\n    )\n\n    allowed_lora_source_formats = {\"huggingface-safetensor\", \"huggingface-torch\"}\n    if lora_adapter is not None and source_format not in allowed_lora_source_formats:\n        raise ValueError(\n            \"`--lora-adapter` only supports source formats: \"\n            f\"{sorted(allowed_lora_source_formats)}\"\n        )\n\n    if lora_adapter is not None:\n        with _merge_lora_adapter_with_base_model(source, lora_adapter) as merged_model_dir:\n            merged_source, merged_source_format = detect_weight(\n                weight_path=merged_model_dir,\n                config_json_path=config,\n                weight_format=\"auto\",\n            )\n            merged_args = dataclasses.replace(\n                args, source=merged_source, source_format=merged_source_format\n            )\n            merged_args.display()\n            _convert_args(merged_args)\n            return\n\n    args.display()\n    _convert_args(args)\n"
  },
  {
    "path": "python/mlc_llm/interface/gen_config.py",
    "content": "\"\"\"Generator of mlc-chat-config.json and tokenizer configuration.\"\"\"\n\n# pylint: disable=E1101\nimport dataclasses\nimport json\nimport re\nimport shutil\nfrom dataclasses import asdict\nfrom pathlib import Path\nfrom typing import Optional\n\nfrom mlc_llm.conversation_template import ConvTemplateRegistry\nfrom mlc_llm.model import Model\nfrom mlc_llm.protocol.mlc_chat_config import MLCChatConfig\nfrom mlc_llm.quantization import Quantization\nfrom mlc_llm.support import convert_tiktoken, logging\nfrom mlc_llm.support.style import bold, green, red\nfrom mlc_llm.tokenizers import Tokenizer\n\nfrom .compiler_flags import ModelConfigOverride\n\nlogger = logging.getLogger(__name__)\n\nFOUND = green(\"Found\")\nNOT_FOUND = red(\"Not found\")\nFAILED = red(\"Failed\")\n\n\ndef apply_system_defaults_for_missing_fields(mlc_chat_config: MLCChatConfig) -> None:\n    \"\"\"Apply system default value.\"\"\"\n    for key, value in mlc_chat_config.get_system_defaults_for_missing_fields().items():\n        setattr(mlc_chat_config, key, value)\n        logger.info(\"[System default] Setting %s: %s\", bold(key), value)\n\n\ndef check_string(s: str) -> bool:\n    \"\"\"Check whether it's a string.\"\"\"\n    s = s[1:] if s[0] == \"b\" else s\n    delimit = s[0]\n    if s[-1] != delimit or delimit not in [\"'\", '\"']:\n        return False\n    for i in range(1, len(s) - 1):\n        if s[i] == delimit and s[i - 1] != \"\\\\\":\n            return False\n    return True\n\n\ndef txt2rwkv_tokenizer(vocab: Path, out: Path) -> None:\n    \"\"\"Generate tokenizer_model from RWKV vocab file.\"\"\"\n    idx2token = {}\n\n    with vocab.open(\"r\", encoding=\"utf-8\") as f:\n        lines = f.readlines()\n\n    for l in lines:\n        idx = int(l[: l.index(\" \")])\n        raw = l[l.index(\" \") : l.rindex(\" \")].strip()\n        if check_string(raw):\n            x = eval(raw)  # pylint: disable=eval-used\n            x = x.encode(\"utf-8\") if isinstance(x, str) 
else x\n            assert isinstance(x, bytes)\n            assert len(x) == int(l[l.rindex(\" \") :])\n            idx2token[idx] = x\n        else:\n            raise ValueError(\"Unsupported vocab dictionary\")\n\n    with (out / \"tokenizer_model\").open(\"wb\") as f:\n        import msgpack  # pylint: disable=import-outside-toplevel,import-error\n\n        msgpack.pack(idx2token, f)\n\n\ndef json2rwkv_tokenizer(vocab: Path, out: Path) -> None:\n    \"\"\"Generate tokenizer_model from RWKV vocab file.\"\"\"\n    idx2token = {}\n\n    with vocab.open(\"r\", encoding=\"utf-8\") as f:\n        data = json.load(f)\n        for key, value in data.items():\n            x = key.encode(\"utf-8\") if isinstance(key, str) else key\n            assert isinstance(x, bytes)\n            idx2token[int(value)] = x\n\n    with (out / \"tokenizer_model\").open(\"wb\") as f:\n        import msgpack  # pylint: disable=import-outside-toplevel,import-error\n\n        msgpack.pack(idx2token, f)\n\n\ndef gen_config(  # pylint: disable=too-many-locals,too-many-arguments,too-many-branches,too-many-statements\n    config: Path,\n    model: Model,\n    quantization: Quantization,\n    conv_template: str,\n    context_window_size: Optional[int],\n    sliding_window_size: Optional[int],\n    prefill_chunk_size: Optional[int],\n    attention_sink_size: Optional[int],\n    tensor_parallel_shards: Optional[int],\n    pipeline_parallel_stages: Optional[int],\n    disaggregation: Optional[bool],\n    max_batch_size: int,\n    output: Path,\n):\n    \"\"\"Entrypoint of MLC Chat configuration generation.\"\"\"\n    # Step 1. 
Initialize `mlc-chat-config.json` using `config.json`\n    conversation_reg = ConvTemplateRegistry.get_conv_template(conv_template)\n    if conversation_reg is None:\n        logger.warning(\n            \"%s: Conversation template is not registered in ConvTemplateRegistry: %s\",\n            red(\"Warning\"),\n            conv_template,\n        )\n        conversation = conv_template  # type: ignore\n    else:\n        conversation = conversation_reg.to_json_dict()  # type: ignore\n\n    model_config = ModelConfigOverride(\n        context_window_size=context_window_size,\n        sliding_window_size=sliding_window_size,\n        prefill_chunk_size=prefill_chunk_size,\n        attention_sink_size=attention_sink_size,\n        max_batch_size=max_batch_size,\n        tensor_parallel_shards=tensor_parallel_shards,\n        pipeline_parallel_stages=pipeline_parallel_stages,\n        disaggregation=disaggregation,\n    ).apply(model.config.from_file(config))\n    mlc_chat_config = MLCChatConfig(\n        model_type=model.name,\n        quantization=quantization.name,\n        model_config=model_config.asdict(),\n        vocab_size=model_config.vocab_size,\n        active_vocab_size=getattr(model_config, \"active_vocab_size\", model_config.vocab_size),\n        context_window_size=getattr(model_config, \"context_window_size\", -1),\n        sliding_window_size=getattr(model_config, \"sliding_window_size\", -1),\n        prefill_chunk_size=model_config.prefill_chunk_size,\n        attention_sink_size=getattr(model_config, \"attention_sink_size\", -1),\n        tensor_parallel_shards=model_config.tensor_parallel_shards,\n        pipeline_parallel_stages=getattr(model_config, \"pipeline_parallel_stages\", 1),\n        disaggregation=getattr(model_config, \"disaggregation\", False),\n        conv_template=conversation,  # type: ignore\n        model_task=model.model_task,\n        embedding_metadata=(\n            dataclasses.asdict(model.embedding_metadata) if 
model.embedding_metadata else None\n        ),\n    )\n    # Step 2. Load `generation_config.json` and `config.json` for text-generation related configs\n    for generation_config_filename in [\"generation_config.json\", \"config.json\"]:\n        generation_config = config.parent / generation_config_filename\n        if generation_config.exists():\n            with generation_config.open(\"r\", encoding=\"utf-8\") as in_file:\n                generation_config_json = json.load(in_file)\n            for key, value in generation_config_json.items():\n                if hasattr(mlc_chat_config, key) and getattr(mlc_chat_config, key) is None:\n                    setattr(mlc_chat_config, key, value)\n                    logger.info(\n                        \"[%s] Setting %s: %s\",\n                        generation_config_filename,\n                        bold(key),\n                        value,\n                    )\n        else:\n            logger.info(\"%s %s: %s\", NOT_FOUND, generation_config_filename, generation_config)\n\n    # Step 3. Copy tokenizer configuration\n    # 3.1. Copy over the files and populate mlc_chat_config\n    for filename in TOKENIZER_FILES:\n        file = config.parent / filename\n        if file.exists():\n            mlc_chat_config.tokenizer_files.append(filename)\n            dest = output / filename\n            shutil.copy(file, dest)\n            logger.info(\"%s tokenizer config: %s. Copying to %s\", FOUND, file, bold(str(dest)))\n        else:\n            logger.info(\"%s tokenizer config: %s\", NOT_FOUND, file)\n    # 3.2. Generate `tokenizer_model` for rwkv if `rwkv_vocab_.*` is found\n    pattern = re.compile(r\"rwkv_vocab_v\\d{8}\\.(json|txt)\")\n    for item in config.parent.iterdir():\n        if item.is_file() and pattern.match(item.name):\n            logger.info(\n                \"%s RWKV vocab file: %s. 
Generating %s\",\n                FOUND,\n                item,\n                bold(\"tokenizer_model\"),\n            )\n            if item.name.endswith(\".txt\"):\n                txt2rwkv_tokenizer(item, output)\n            else:\n                json2rwkv_tokenizer(item, output)\n    # 3.3. If we have `tokenizer.model` but not `tokenizer.json`, try to convert it to\n    # `tokenizer.json` with `transformers`.\n    tokenizer_json_file = config.parent / \"tokenizer.json\"\n    tokenizer_model_file = config.parent / \"tokenizer.model\"\n    if tokenizer_model_file.exists() and (not tokenizer_json_file.exists()):\n        logger.info(\n            \"The model has `tokenizer.model` but not `tokenizer.json`. \"\n            \"It is always recommended to prefer JSON instead. \"\n            \"Attempting to convert using HuggingFace transformers library\"\n        )\n        try:\n            from transformers import (  # pylint: disable=import-error,import-outside-toplevel\n                AutoTokenizer,\n            )\n\n            tokenizer_json_save_dest = output / \"tokenizer.json\"\n            fast_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True)\n            fast_tokenizer.backend_tokenizer.save(str(tokenizer_json_save_dest))\n            mlc_chat_config.tokenizer_files.append(\"tokenizer.json\")\n            logger.info(\n                \"Successfully converted `tokenizer.model` to: %s\",\n                tokenizer_json_save_dest,\n            )\n        except Exception:  # pylint: disable=broad-exception-caught\n            logger.warning(\n                \"Converting to `tokenizer.json` %s with the exception below. \"\n                \"Skipping the conversion.\",\n                FAILED,\n                exc_info=True,\n            )\n    # 3.4. 
If we still don't have \"tokenizer.json\" at this point, try looking for \"*.tiktoken\" files\n    if (not tokenizer_json_file.exists()) and list(config.parent.glob(\"*.tiktoken\")):\n        try:\n            logger.info(\n                \"The model has tiktoken files but not `tokenizer.json`. \"\n                \"Attempting to convert from tiktoken files\"\n            )\n            convert_tiktoken.convert_tiktoken(\n                str(config.parent), str(output), mlc_chat_config.context_window_size\n            )\n            mlc_chat_config.tokenizer_files.append(\"tokenizer.json\")\n            mlc_chat_config.tokenizer_files.append(\"vocab.json\")\n            mlc_chat_config.tokenizer_files.append(\"merges.txt\")\n            mlc_chat_config.tokenizer_files.append(\"special_tokens_map.json\")\n            logger.info(\"Successfully converted from tiktoken files to: %s\", str(output))\n        except Exception:  # pylint: disable=broad-exception-caught\n            logger.exception(\"%s with the exception below. Skipping\", FAILED)\n\n    # 3.5. Detect tokenizer info\n    mlc_chat_config.tokenizer_info = asdict(Tokenizer.detect_tokenizer_info(str(output)))\n    logger.info(\"Detected tokenizer info: %s\", mlc_chat_config.tokenizer_info)\n\n    # 3.6. 
Ensure `added_tokens` contains no duplicated entries, a mistake made by some model releasers\n    # that breaks the correctness of the HuggingFace tokenizer.\n    # See https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/discussions/15.\n    if tokenizer_json_file.exists():\n        with open(tokenizer_json_file, \"r\", encoding=\"utf-8\") as f:\n            tokenizer_json = json.load(f)\n            if \"added_tokens\" in tokenizer_json:\n                appeared_content = set()\n                for added_token in tokenizer_json[\"added_tokens\"]:\n                    content = added_token[\"content\"]\n                    if content in appeared_content:\n                        logger.error(\n                            \"%s due to an incorrect tokenizer.json that has duplicated token %s. \"\n                            \"This affects the correctness of the HuggingFace tokenizer at runtime; \"\n                            \"please remove the duplicated entry from your tokenizer.json manually.\",\n                            FAILED,\n                            content,\n                        )\n                        raise ValueError(\"Duplicated vocab in tokenizer.json\")\n                    appeared_content.add(content)\n\n    # Step 4. Load system default values\n    apply_system_defaults_for_missing_fields(mlc_chat_config)\n\n    # Step 5. 
Use the HF tokenizer to detect the active vocab size via len(tokenizer)\n    if tokenizer_json_file.exists():\n        try:\n            from transformers import (  # pylint: disable=import-error,import-outside-toplevel\n                AutoTokenizer,\n            )\n\n            hf_tokenizer = AutoTokenizer.from_pretrained(str(config.parent), use_fast=True)\n            active_vocab_size = len(hf_tokenizer)\n            if mlc_chat_config.active_vocab_size != active_vocab_size:\n                logger.info(\n                    \"Overriding active_vocab_size from %d to %d using HF tokenizer\",\n                    mlc_chat_config.active_vocab_size,\n                    active_vocab_size,\n                )\n                mlc_chat_config.active_vocab_size = active_vocab_size\n        except Exception:  # pylint: disable=broad-exception-caught\n            logger.warning(\n                \"Detecting active_vocab_size %s with the exception below. Skipping.\",\n                FAILED,\n                exc_info=True,\n            )\n\n    # Step 6. 
Dump the configuration file to output directory\n    with (output / \"mlc-chat-config.json\").open(\"w\", encoding=\"utf-8\") as out_file:\n        json.dump(mlc_chat_config.model_dump(by_alias=True), out_file, indent=2)\n        logger.info(\"Dumping configuration file to: %s\", bold(out_file.name))\n\n\nTOKENIZER_FILES = [\n    \"tokenizer.model\",\n    \"tokenizer.json\",\n    \"vocab.json\",\n    \"merges.txt\",\n    \"added_tokens.json\",\n    \"tokenizer_config.json\",\n]\n# FIXME: Copy RWKV tokenizer file # pylint: disable=fixme\n\nCONV_TEMPLATES = {\n    \"llama-4\",\n    \"llama-3\",\n    \"llama-3_1\",\n    \"chatml\",\n    \"chatml_nosystem\",\n    \"qwen2\",\n    \"open_hermes_mistral\",\n    \"neural_hermes_mistral\",\n    \"llama_default\",\n    \"llama-2\",\n    \"mistral_default\",\n    \"ministral3\",\n    \"ministral3_reasoning\",\n    \"gpt2\",\n    \"codellama_completion\",\n    \"codellama_instruct\",\n    \"redpajama_chat\",\n    \"rwkv_world\",\n    \"gorilla\",\n    \"gorilla-openfunctions-v2\",\n    \"dolly\",\n    \"oasst\",\n    \"stablelm\",\n    \"LM\",\n    \"stablelm-3b\",\n    \"gpt_bigcode\",\n    \"wizardlm_7b\",\n    \"wizard_coder_or_math\",\n    \"glm\",\n    \"phi-2\",\n    \"phi-3\",\n    \"phi-3-vision\",\n    \"phi-4\",\n    \"stablelm-2\",\n    \"gemma_instruction\",\n    \"gemma3_instruction\",\n    \"orion\",\n    \"llava\",\n    \"hermes2_pro_llama3\",\n    \"hermes3_llama-3_1\",\n    \"tinyllama_v1_0\",\n    \"aya-23\",\n    \"deepseek\",\n    \"deepseek_v2\",\n    \"deepseek_v3\",\n    \"deepseek_r1_qwen\",\n    \"deepseek_r1_llama\",\n    \"olmo\",\n    \"nemotron\",\n    \"llm-jp\",\n}\n"
  },
  {
    "path": "python/mlc_llm/interface/help.py",
    "content": "\"\"\"Help message for CLI arguments.\"\"\"\n\nHELP = {\n    \"config\": (\n        \"\"\"\n1) Path to a HuggingFace model directory that contains a `config.json` or\n2) Path to `config.json` in HuggingFace format, or\n3) The name of a pre-defined model architecture.\n\nA `config.json` file in HuggingFace format defines the model architecture, including the vocabulary\nsize, the number of layers, the hidden size, number of attention heads, etc.\nExample: https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json.\n\nA HuggingFace directory often contains a `config.json` which defines the model architecture,\nthe non-quantized model weights in PyTorch or SafeTensor format, tokenizer configurations,\nas well as an optional `generation_config.json` that provides additional default configuration for\ntext generation.\nExample: https://huggingface.co/codellama/CodeLlama-7b-hf/tree/main.\n\"\"\"\n    ).strip(),\n    \"quantization\": \"\"\"\nThe quantization mode used for compilation. If not provided, it is inferred from `model`.\n\"\"\".strip(),\n    \"model\": \"\"\"\nA path to ``mlc-chat-config.json``, or an MLC model directory that contains `mlc-chat-config.json`.\nIt can also be a link to a HF repository pointing to an MLC compiled model.\n\"\"\".strip(),\n    \"model_lib\": \"\"\"\nThe full path to the model library file to use (e.g. a ``.so`` file). If unspecified, we will use\nthe provided ``model`` to search over possible paths. If the model lib is not found, it will be\ncompiled in a JIT manner.\n\"\"\".strip(),\n    \"model_type\": \"\"\"\nModel architecture such as \"llama\". If not set, it is inferred from `mlc-chat-config.json`.\n\"\"\".strip(),\n    \"device_compile\": \"\"\"\nThe GPU device to compile the model to. If not set, it is inferred from GPUs available locally.\n\"\"\".strip(),\n    \"device_quantize\": \"\"\"\nThe device used to do quantization such as \"cuda\" or \"cuda:0\". 
It is detected from locally available GPUs\nif not specified.\n\"\"\".strip(),\n    \"device_deploy\": \"\"\"\nThe device used to deploy the model such as \"cuda\" or \"cuda:0\". It is detected from locally\navailable GPUs if not specified.\n\"\"\".strip(),\n    \"host\": \"\"\"\nThe host LLVM triple to compile the model to. If not set, it is inferred from the local CPU and OS.\nExamples of the LLVM triple:\n1) iPhones: arm64-apple-ios;\n2) ARM64 Android phones: aarch64-linux-android;\n3) WebAssembly: wasm32-unknown-unknown-wasm;\n4) Windows: x86_64-pc-windows-msvc;\n5) ARM macOS: arm64-apple-darwin.\n\"\"\".strip(),\n    \"opt\": \"\"\"\nOptimization flags. MLC LLM maintains a predefined set of optimization flags,\ndenoted as O0, O1, O2, O3, where O0 means no optimization, O2 enables most of them,\nand O3 represents extreme optimization that could potentially break the system.\nOptimization flags can also be specified explicitly via detailed knobs, e.g.\n--opt=\"cublas_gemm=1;cudagraph=0\".\n\"\"\".strip(),\n    \"system_lib_prefix\": \"\"\"\nAdd a prefix to all exported symbols. Similar to \"objcopy --prefix-symbols\".\nThis is useful when compiling multiple models into a single library to avoid symbol\nconflicts. Unlike objcopy, this has no effect on shared libraries.\n\"\"\".strip(),\n    \"context_window_size\": \"\"\"\nOption to provide the maximum sequence length supported by the model.\nThis is usually explicitly shown as context length or context window in the model card.\nIf this option is not set explicitly, by default,\nit will be determined by `context_window_size` or `max_position_embeddings` in `config.json`,\nand the latter can be inaccurate for some models.\n\"\"\".strip(),\n    \"output_compile\": \"\"\"\nThe path to the output file. The suffix determines if the output file is a shared library or\nobjects. 
Available suffixes:\n1) Linux: .so (shared), .tar (objects);\n2) macOS: .dylib (shared), .tar (objects);\n3) Windows: .dll (shared), .tar (objects);\n4) Android, iOS: .tar (objects);\n5) Web: .wasm (WebAssembly).\n\"\"\".strip(),\n    \"source\": \"\"\"\nThe path to the original model weights, inferred from `config` if missing.\n\"\"\".strip(),\n    \"source_format\": \"\"\"\nThe format of the source model weights, inferred from `config` if missing.\n\"\"\".strip(),\n    \"output_quantize\": \"\"\"\nThe output directory to save the quantized model weights. This creates `params_shard_*.bin` and\n`tensor-cache.json` in this directory.\n\"\"\".strip(),\n    \"conv_template\": \"\"\"\nConversation template. It depends on how the model is tuned. Use \"LM\" for a vanilla base model.\n\"\"\".strip(),\n    \"output_gen_mlc_chat_config\": \"\"\"\nThe output directory for generated configurations, including `mlc-chat-config.json` and tokenizer\nconfiguration.\n\"\"\".strip(),\n    \"sliding_window_size\": \"\"\"\n(Experimental) The sliding window size in sliding window attention (SWA).\nThis optional field overrides the `sliding_window_size` in config.json for\nthose models that use SWA. Currently only useful when compiling Mistral.\nThis flag is subject to future refactoring.\n\"\"\".strip(),\n    \"prefill_chunk_size\": \"\"\"\n(Experimental) The chunk size during prefilling. By default,\nthe chunk size is the same as the sliding window size or max sequence length.\nThis flag is subject to future refactoring.\n\"\"\".strip(),\n    \"attention_sink_size\": \"\"\"\n(Experimental) The number of stored sinks. Currently only supported on Mistral. By default,\nthe number of sinks is 4. 
This flag is subject to future refactoring.\n\"\"\".strip(),\n    \"max_batch_size\": \"\"\"\nThe maximum batch size that the KV cache is configured to support concurrently.\n\"\"\".strip(),\n    \"\"\"tensor_parallel_shards\"\"\": \"\"\"\nNumber of shards to split the model into for multi-GPU tensor-parallel inference.\n\"\"\".strip(),\n    \"\"\"pipeline_parallel_stages\"\"\": \"\"\"\nNumber of pipeline stages to split the model layers into for pipeline parallelism.\n\"\"\".strip(),\n    \"\"\"disaggregation\"\"\": \"\"\"\nWhether to enable disaggregation when compiling the model.\n\"\"\".strip(),\n    \"overrides\": \"\"\"\nModel configuration override. Configurations to override `mlc-chat-config.json`. Supports\n`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`,\n`max_batch_size` and `tensor_parallel_shards`. The model config can be explicitly\nspecified via detailed knobs, e.g. --overrides \"context_window_size=1024;prefill_chunk_size=128\".\n\"\"\".strip(),\n    \"modelconfig_overrides\": \"\"\"\nModel configuration override. Supports overriding\n`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`,\n`max_num_sequence` and `tensor_parallel_shards`. The overrides can be explicitly\nspecified via detailed knobs, e.g. --overrides \"context_window_size=1024;prefill_chunk_size=128\".\n\"\"\".strip(),\n    \"debug_dump\": \"\"\"\nSpecifies the directory where the compiler will store its IRs for debugging purposes\nduring various phases of compilation. By default, this is set to `None`, indicating\nthat debug dumping is disabled.\n\"\"\".strip(),\n    \"prompt\": \"\"\"\nThe prompt for text generation.\n\"\"\".strip(),\n    \"generate_length\": \"\"\"\nThe target length of the text generation.\n\"\"\".strip(),\n    \"max_total_sequence_length_serve\": \"\"\"\nThe KV cache total token capacity, i.e., the maximum total number of tokens that\nthe KV cache supports. 
This decides the GPU memory size that the KV cache consumes.\nIf not specified, the system will automatically estimate the maximum capacity based\non the VRAM size of the GPU.\n\"\"\".strip(),\n    \"prefill_chunk_size_serve\": \"\"\"\nThe maximum number of tokens the model passes for prefill each time.\nIt should not exceed the prefill chunk size in the model config.\nIf not specified, this defaults to the prefill chunk size in the model config.\n\"\"\".strip(),\n    \"max_history_size_serve\": \"\"\"\nThe maximum history length for rolling back the RNN state.\nIf unspecified, the default value is 1.\nKV cache does not need this.\n\"\"\".strip(),\n    \"enable_tracing_serve\": \"\"\"\nEnable Chrome Tracing for the server.\nAfter enabling, you can send a POST request to the \"debug/dump_event_trace\" endpoint\nto get the Chrome Trace. For example,\n\"curl -X POST http://127.0.0.1:8000/debug/dump_event_trace -H \"Content-Type: application/json\" -d '{\"model\": \"dist/llama\"}'\"\n\"\"\".strip(),\n    \"mode_serve\": \"\"\"\nThe engine mode in MLC LLM. We provide three preset modes: \"local\", \"interactive\" and \"server\".\nThe default mode is \"local\".\nThe choice of mode decides the values of \"max_num_sequence\", \"max_total_seq_length\" and\n\"prefill_chunk_size\" when they are not explicitly specified.\n1. Mode \"local\" refers to the local server deployment which has low request concurrency.\n   So the max batch size will be set to 4, and max total sequence length and prefill chunk size\n   are set to the context window size (or sliding window size) of the model.\n2. Mode \"interactive\" refers to the interactive use of the server, which has at most 1 concurrent\n   request. So the max batch size will be set to 1, and max total sequence length and prefill\n   chunk size are set to the context window size (or sliding window size) of the model.\n3. 
Mode \"server\" refers to the large server use case, which may handle many concurrent requests\n   and wants to use as much GPU memory as possible. In this mode, we will automatically infer\n   the largest possible max batch size and max total sequence length.\nYou can manually specify arguments \"max_num_sequence\", \"max_total_seq_length\" and\n\"prefill_chunk_size\" via \"--overrides\" to override the automatically inferred values.\nFor example: --overrides \"max_num_sequence=32;max_total_seq_length=4096\"\n\"\"\".strip(),\n    \"additional_models_serve\": \"\"\"\nThe model paths and (optional) model library paths of additional models (other than the main model).\nWhen the engine is enabled with speculative decoding, additional models are needed.\nAdditional models are specified as:\n\"--additional-models model_path_1 model_path_2 ...\" or\n\"--additional-models model_path_1,model_lib_1 model_path_2 ...\".\nWhen the model lib of a model is not given, JIT model compilation will be activated\nto compile the model automatically.\n\"\"\".strip(),\n    \"gpu_memory_utilization_serve\": \"\"\"\nA number in (0, 1) denoting the fraction of GPU memory used by the server in total.\nIt is used to infer the maximum possible KV cache capacity.\nWhen it is unspecified, it defaults to 0.85.\nUnder mode \"local\" or \"interactive\", the actual memory usage may be significantly smaller than\nthis number. Under mode \"server\", the actual memory usage may be slightly larger than this number.\n\"\"\".strip(),\n    \"speculative_mode_serve\": \"\"\"\nThe speculative decoding mode. 
Right now four options are supported:\n - \"disable\", where speculative decoding is not enabled,\n - \"small_draft\", denoting the normal speculative decoding (small draft) style,\n - \"eagle\", denoting the eagle-style speculative decoding,\n - \"medusa\", denoting the medusa-style speculative decoding.\nThe default mode is \"disable\".\n\"\"\".strip(),\n    \"spec_draft_length_serve\": \"\"\"\nThe number of draft tokens to generate in a speculative proposal.\nA value of 0 enables adaptive speculative mode, where the draft length is\nautomatically adjusted based on the engine state. The default value is 0.\n\"\"\".strip(),\n    \"prefix_cache_mode_serve\": \"\"\"\nThe prefix cache mode. Right now two options are supported:\n - \"disable\", where prefix cache is not enabled,\n - \"radix\", denoting the normal paged radix tree based prefix cache.\nThe default mode is \"radix\".\n\"\"\".strip(),\n    \"prefix_cache_max_num_recycling_seqs_serve\": \"\"\"\nThe maximum number of sequences in the prefix cache, defaulting to max_batch_size.\nSet it to 0 to disable the prefix cache, or to -1 for an unlimited-capacity prefix cache.\n\"\"\".strip(),\n    \"prefill_mode\": \"\"\"\nThe prefill mode. \"chunked\" means the basic prefill with chunked input enabled. 
\"hybrid\" means the\nhybrid prefill or split-fuse, in which decode steps are converted into prefill.\n\"\"\".strip(),\n    \"overrides_serve\": \"\"\"\nOverriding extra configurable fields of EngineConfig and model compilation config.\nFields that can be overridden: \"tensor_parallel_shards\", \"max_num_sequence\",\n\"max_total_seq_length\", \"prefill_chunk_size\", \"max_history_size\", \"gpu_memory_utilization\",\n\"spec_draft_length\", \"prefix_cache_max_num_recycling_seqs\", \"context_window_size\",\n\"sliding_window_size\", \"attention_sink_size\".\nPlease check out the documentation of EngineConfig in mlc_llm/serve/config.py for the detailed docstring\nof each field.\nExample: --overrides \"max_num_sequence=32;max_total_seq_length=4096;tensor_parallel_shards=2\"\n\"\"\".strip(),\n    \"config_package\": \"\"\"\nThe path to \"mlc-package-config.json\" which is used for package build.\nSee \"https://github.com/mlc-ai/mlc-llm/blob/main/ios/MLCChat/mlc-package-config.json\" as an example.\n\"\"\".strip(),\n    \"mlc_llm_source_dir\": \"\"\"\nThe source code path to MLC LLM.\n\"\"\".strip(),\n    \"output_package\": \"\"\"\nThe path of the output directory for the package build outputs.\n\"\"\".strip(),\n    \"calibration_dataset\": \"\"\"\nThe path to the calibration dataset.\n    \"\"\".strip(),\n    \"num_calibration_samples\": \"\"\"\nThe number of samples used for calibration.\n    \"\"\".strip(),\n    \"output_calibration\": \"\"\"\nThe output directory to save the calibration params.\n    \"\"\".strip(),\n    \"seed_calibrate\": \"\"\"\nThe seed used to sample the calibration dataset.\n\"\"\".strip(),\n    \"pd_balance_factor\": \"\"\"\nHow much prefill to move to the decode engine. For example,\n0.1 means the last 10 percent of tokens are prefilled by the decode engine.\n    \"\"\".strip(),\n}\n"
  },
  {
    "path": "python/mlc_llm/interface/jit.py",
    "content": "\"\"\"Just-in-time compilation of MLC-Chat models.\"\"\"\n\nimport dataclasses\nimport hashlib\nimport json\nimport os\nimport shlex\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nfrom tvm.runtime import Device\n\nfrom mlc_llm.model import MODELS\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.auto_device import device2str\nfrom mlc_llm.support.constants import (\n    MLC_DSO_SUFFIX,\n    MLC_JIT_POLICY,\n    MLC_LLM_HOME,\n    MLC_TEMP_DIR,\n)\nfrom mlc_llm.support.style import blue, bold\n\nfrom .compiler_flags import ModelConfigOverride, OptimizationFlags\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass JITResult:\n    \"\"\"The JIT compilation result class.\"\"\"\n\n    model_lib_path: str\n    system_lib_prefix: Optional[str] = None\n\n\ndef log_jit_policy():\n    \"\"\"Log the current JIT policy.\"\"\"\n    logger.info(\n        \"%s = %s. Can be one of: ON, OFF, REDO, READONLY\",\n        bold(\"MLC_JIT_POLICY\"),\n        MLC_JIT_POLICY,\n    )\n\n\ndef jit(  # pylint: disable=too-many-locals,too-many-statements\n    model_path: Path,\n    overrides: Dict[str, Any],\n    device: Union[Device, str],\n    system_lib_prefix: Optional[str] = None,\n    *,\n    skip_log_jit_policy=False,\n) -> JITResult:\n    \"\"\"Just-in-time compile an MLC-Chat model.\"\"\"\n    # Skip logging the JIT policy when the caller has already logged it once\n    if not skip_log_jit_policy:\n        log_jit_policy()\n\n    if MLC_JIT_POLICY == \"OFF\":\n        raise RuntimeError(\"JIT is disabled by MLC_JIT_POLICY=OFF\")\n\n    with open(model_path / \"mlc-chat-config.json\", \"r\", encoding=\"utf-8\") as in_file:\n        mlc_chat_config = json.load(in_file)\n    model_type = mlc_chat_config.pop(\"model_type\")\n    quantization = mlc_chat_config.pop(\"quantization\")\n    lib_suffix = MLC_DSO_SUFFIX if device not in [\"iphone\", \"macabi\", \"android\"] else 
\"tar\"\n\n    def _get_optimization_flags() -> str:\n        opt = overrides.pop(\"opt\", None)\n        if opt is None:\n            opt = \"O2\"\n        return repr(OptimizationFlags.from_str(opt))\n\n    def _get_overrides() -> str:\n        forbid_list = [\n            \"context_window_size\",\n            \"sliding_window_size\",\n            \"attention_sink_size\",\n        ]\n        result = []\n        for field in dataclasses.fields(ModelConfigOverride):\n            value = overrides.get(field.name, None)\n            if value is not None:\n                if field.name in forbid_list and value == -1:\n                    continue\n                result.append(f\"{field.name}={value}\")\n        return \";\".join(result)\n\n    def _get_model_config() -> Dict[str, Any]:\n        model_config = mlc_chat_config.pop(\"model_config\")\n        model_config.update(mlc_chat_config)\n        for field in dataclasses.fields(ModelConfigOverride):\n            value = overrides.get(field.name, None)\n            if value is not None:\n                model_config[field.name] = value\n        return MODELS[model_type].config.from_dict(model_config).asdict()\n\n    def _run_jit(\n        opt: str,\n        overrides: str,\n        device: str,\n        system_lib_prefix: Optional[str],\n        dst: str,\n    ):\n        with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir:\n            dso_path = os.path.join(tmp_dir, f\"lib.{lib_suffix}\")\n            cmd = [\n                sys.executable,\n                \"-m\",\n                \"mlc_llm\",\n                \"compile\",\n                str(model_path),\n                \"--opt\",\n                opt,\n                \"--overrides\",\n                overrides,\n                \"--device\",\n                device,\n                \"--output\",\n                dso_path,\n            ]\n            if system_lib_prefix:\n                cmd += [\"--system-lib-prefix\", system_lib_prefix + 
\"_\"]\n            logger.info(\"Compiling using commands below:\")\n            logger.info(\"%s\", blue(shlex.join(cmd)))\n            subprocess.run(cmd, check=False, env=os.environ)\n            # note on windows: compilation can succeed but return code is still nonzero\n            # check whether file exists instead\n            if not os.path.isfile(dso_path):\n                raise RuntimeError(\"Cannot find compilation output, compilation failed\")\n            shutil.move(dso_path, dst)\n            logger.info(\"Using compiled model lib: %s\", bold(dst))\n\n    hash_key = {\n        \"model_config\": _get_model_config(),\n        \"overrides\": _get_overrides(),\n        \"opt\": _get_optimization_flags(),\n        \"device\": device2str(device) if isinstance(device, Device) else device,\n        \"model_type\": model_type,\n        \"quantization\": quantization,\n    }\n    if device in [\"iphone\", \"macabi\", \"android\"]:\n        if system_lib_prefix is None:\n            system_lib_hash_value = hashlib.md5(\n                json.dumps(\n                    hash_key,\n                    sort_keys=True,\n                    indent=2,\n                ).encode(\"utf-8\")\n            ).hexdigest()\n            system_lib_prefix = f\"{model_type}_{quantization}_{system_lib_hash_value}\".replace(\n                \"-\", \"_\"\n            )\n        hash_key[\"system_lib_prefix\"] = system_lib_prefix\n    hash_value = hashlib.md5(\n        json.dumps(\n            hash_key,\n            sort_keys=True,\n            indent=2,\n        ).encode(\"utf-8\")\n    ).hexdigest()\n    dst = MLC_LLM_HOME / \"model_lib\" / f\"{hash_value}.{lib_suffix}\"\n    if dst.is_file() and MLC_JIT_POLICY in [\"ON\", \"READONLY\"]:\n        logger.info(\"Using cached model lib: %s\", bold(str(dst)))\n        return JITResult(str(dst), system_lib_prefix)\n    if MLC_JIT_POLICY == \"READONLY\":\n        raise RuntimeError(\n            \"No cached model lib found, and JIT 
is disabled by MLC_JIT_POLICY=READONLY\"\n        )\n    _run_jit(\n        opt=hash_key[\"opt\"],\n        overrides=hash_key[\"overrides\"],\n        device=hash_key[\"device\"],\n        system_lib_prefix=system_lib_prefix,\n        dst=str(dst),\n    )\n    return JITResult(str(dst), system_lib_prefix)\n"
  },
  {
    "path": "python/mlc_llm/interface/package.py",
    "content": "\"\"\"Python entrypoint of package.\"\"\"\n\nimport dataclasses\nimport json\nimport os\nimport shutil\nimport subprocess\nimport sys\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Literal\n\nfrom mlc_llm.interface import jit\nfrom mlc_llm.support import download_cache, logging, style\n\nlogging.enable_logging()\nlogger = logging.getLogger(__name__)\n\nSUPPORTED_DEVICES = [\"iphone\", \"macabi\", \"android\"]\n\n\ndef build_model_library(  # pylint: disable=too-many-branches,too-many-locals,too-many-statements\n    package_config: Dict[str, Any], device: str, bundle_dir: Path, app_config_path: Path\n) -> Dict[str, str]:\n    \"\"\"Build model libraries. Return the dictionary of \"library prefix to lib path\".\"\"\"\n    # - Create the bundle directory.\n    os.makedirs(bundle_dir, exist_ok=True)\n    # Clean up all the directories in `output/bundle`.\n    logger.info('Clean up all directories under \"%s\"', str(bundle_dir))\n    for content_path in bundle_dir.iterdir():\n        if content_path.is_dir():\n            shutil.rmtree(content_path)\n\n    # - Process each model, and prepare the app config.\n    app_config_model_list = []\n\n    model_entries = package_config.get(\"model_list\", [])\n    if not isinstance(model_entries, list):\n        raise ValueError('The \"model_list\" in \"mlc-package-config.json\" is expected to be a list.')\n    model_lib_path_for_prepare_libs = package_config.get(\"model_lib_path_for_prepare_libs\", {})\n    if not isinstance(model_lib_path_for_prepare_libs, dict):\n        raise ValueError(\n            'The \"model_lib_path_for_prepare_libs\" in \"mlc-package-config.json\" is expected to be '\n            \"a dict.\"\n        )\n\n    jit.log_jit_policy()\n\n    for model_entry in package_config.get(\"model_list\", []):\n        # - Parse model entry.\n        if not isinstance(model_entry, dict):\n            raise ValueError('The element of \"model_list\" is expected to be a dict.')\n        
model = model_entry[\"model\"]\n        model_id = model_entry[\"model_id\"]\n        bundle_weight = model_entry.get(\"bundle_weight\", False)\n        overrides = model_entry.get(\"overrides\", {})\n        model_lib = model_entry.get(\"model_lib\", None)\n\n        estimated_vram_bytes = model_entry[\"estimated_vram_bytes\"]\n        if not isinstance(model, str):\n            raise ValueError('The value of \"model\" in \"model_list\" is expected to be a string.')\n        if not isinstance(model_id, str):\n            raise ValueError('The value of \"model_id\" in \"model_list\" is expected to be a string.')\n        if not isinstance(bundle_weight, bool):\n            raise ValueError(\n                'The value of \"bundle_weight\" in \"model_list\" is expected to be a boolean.'\n            )\n        if not isinstance(overrides, dict):\n            raise ValueError('The value of \"overrides\" in \"model_list\" is expected to be a dict.')\n        if model_lib is not None and not isinstance(model_lib, str):\n            raise ValueError('The value of \"model_lib\" in \"model_list\" is expected to be string.')\n\n        # - Load model config. Download happens when needed.\n        model_path = download_cache.get_or_download_model(model)\n\n        # - Jit compile if the model lib path is not specified.\n        model_lib_path = (\n            model_lib_path_for_prepare_libs.get(model_lib, None) if model_lib is not None else None\n        )\n        if model_lib_path is None:\n            if model_lib is None:\n                logger.info(\n                    'Model lib is not specified for model \"%s\". 
Now JIT-compiling the model library.',\n                    model_id,\n                )\n            else:\n                logger.info(\n                    'Model lib path for \"%s\" is not specified in \"model_lib_path_for_prepare_libs\". '\n                    \"Now JIT-compiling the model library.\",\n                    model_lib,\n                )\n            model_lib_path, model_lib = dataclasses.astuple(\n                jit.jit(\n                    model_path=model_path,\n                    overrides=overrides,\n                    device=device,\n                    system_lib_prefix=model_lib,\n                    skip_log_jit_policy=True,\n                )\n            )\n            assert model_lib is not None\n            model_lib_path_for_prepare_libs[model_lib] = model_lib_path\n\n        # - Set \"model_url\"/\"model_path\" and \"model_id\"\n        app_config_model_entry = {}\n        is_local_model = not model.startswith(\"HF://\") and not model.startswith(\"https://\")\n        app_config_model_entry[\"model_id\"] = model_id\n        app_config_model_entry[\"model_lib\"] = model_lib\n\n        # - Bundle weight\n        if is_local_model and not bundle_weight:\n            raise ValueError(\n                f'Model \"{model}\" in \"model_list\" is a local path. '\n                f'Please set \\'\"bundle_weight\": true\\' in the entry of model \"{model}\".'\n            )\n        if bundle_weight:\n            if not os.path.isfile(model_path / \"tensor-cache.json\"):\n                raise ValueError(\n                    f'Bundle weight is set for model \"{model}\". However, model weights are not '\n                    f'found under the directory \"{model}\". 
'\n                    + (\n                        \"Please follow https://llm.mlc.ai/docs/compilation/convert_weights.html to \"\n                        \"convert model weights.\"\n                        if is_local_model\n                        else \"Please report this issue to https://github.com/mlc-ai/mlc-llm/issues.\"\n                    )\n                )\n            # Overwrite the model weight directory in bundle.\n            bundle_model_weight_path = bundle_dir / model_id\n            logger.info(\n                \"Bundle weight for %s, copy into %s\",\n                style.bold(model_id),\n                style.bold(str(bundle_model_weight_path)),\n            )\n            if bundle_model_weight_path.exists():\n                shutil.rmtree(bundle_model_weight_path)\n            shutil.copytree(model_path, bundle_model_weight_path)\n        if bundle_weight and device in [\"iphone\", \"macabi\"]:\n            app_config_model_entry[\"model_path\"] = model_id\n        else:\n            app_config_model_entry[\"model_url\"] = model.replace(\"HF://\", \"https://huggingface.co/\")\n\n        # - estimated_vram_bytes\n        app_config_model_entry[\"estimated_vram_bytes\"] = estimated_vram_bytes\n\n        app_config_model_list.append(app_config_model_entry)\n\n    # - Dump \"mlc-app-config.json\".\n    app_config_json_str = json.dumps(\n        {\"model_list\": app_config_model_list},\n        indent=2,\n    )\n    with open(app_config_path, \"w\", encoding=\"utf-8\") as file:\n        print(app_config_json_str, file=file)\n        logger.info(\n            'Dump the app config below to \"%s\":\\n%s',\n            str(app_config_path),\n            style.green(app_config_json_str),\n        )\n    return model_lib_path_for_prepare_libs\n\n\ndef validate_model_lib(  # pylint: disable=too-many-locals,too-many-statements\n    app_config_path: Path,\n    package_config_path: Path,\n    model_lib_path_for_prepare_libs: dict,\n    device: 
Literal[\"iphone\", \"macabi\", \"android\"],\n    output: Path,\n) -> None:\n    \"\"\"Validate the model lib prefixes of model libraries.\"\"\"\n    # pylint: disable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported\n    if device == \"android\":\n        from tvm.contrib import ndk as cc\n    else:\n        from tvm.contrib import cc\n    # pylint: enable=import-outside-toplevel,redefined-outer-name,shadowed-import,reimported\n\n    with open(app_config_path, \"r\", encoding=\"utf-8\") as file:\n        app_config = json.load(file)\n\n    tar_list = []\n    model_set = set()\n\n    for model, model_lib_path in model_lib_path_for_prepare_libs.items():\n        model_lib_path = os.path.join(model_lib_path)\n        lib_path_valid = os.path.isfile(model_lib_path)\n        if not lib_path_valid:\n            raise RuntimeError(f\"Cannot find file {model_lib_path} as an {device} model library\")\n        tar_list.append(model_lib_path)\n        model_set.add(model)\n\n    os.makedirs(output / \"lib\", exist_ok=True)\n    if device in [\"iphone\", \"macabi\"]:\n        lib_name = \"libmodel_iphone.a\"\n    else:\n        lib_name = \"libmodel_android.a\"\n    lib_path = output / \"lib\" / lib_name\n\n    def _get_model_libs(lib_path: Path) -> List[str]:\n        \"\"\"Get the model lib prefixes in the given static lib path.\"\"\"\n        global_symbol_map = cc.get_global_symbol_section_map(lib_path)\n        libs = []\n        suffix = \"___tvm_ffi__library_bin\"\n        for name, _ in global_symbol_map.items():\n            if name.endswith(suffix):\n                model_lib = name[: -len(suffix)]\n                if model_lib.startswith(\"_\"):\n                    model_lib = model_lib[1:]\n                libs.append(model_lib)\n        return libs\n\n    cc.create_staticlib(lib_path, tar_list)\n    available_model_libs = _get_model_libs(lib_path)\n    logger.info(\"Creating lib from %s\", str(tar_list))\n    logger.info(\"Validating the 
library %s\", str(lib_path))\n    logger.info(\n        \"List of available model libs packaged: %s,\"\n        \" if we have '-' in the model_lib string, it will be turned into '_'\",\n        str(available_model_libs),\n    )\n    global_symbol_map = cc.get_global_symbol_section_map(lib_path)\n    error_happened = False\n\n    for item in app_config[\"model_list\"]:\n        model_lib = item[\"model_lib\"]\n        model_id = item[\"model_id\"]\n        if model_lib not in model_set:\n            # NOTE: this cannot happen under the new setting\n            # since if model_lib is not included, it will be jitted\n            raise RuntimeError(\n                f\"ValidationError: model_lib={model_lib} specified for model_id={model_id} \"\n                \"is not included in the model_lib_path_for_prepare_libs argument. \"\n                \"This will prevent the specific model from loading. \"\n                f\"model_lib_path_for_prepare_libs={model_lib_path_for_prepare_libs}\"\n            )\n\n        model_prefix_pattern = model_lib.replace(\"-\", \"_\") + \"___tvm_ffi__library_bin\"\n        if (\n            model_prefix_pattern not in global_symbol_map\n            and \"_\" + model_prefix_pattern not in global_symbol_map\n        ):\n            # NOTE: no lazy format is ok since this is a slow pass\n            model_lib_path = model_lib_path_for_prepare_libs[model_lib]\n            log_msg = (\n                \"ValidationError:\\n\"\n                f\"\\tmodel_lib {model_lib} requested in {str(app_config_path)}\"\n                f\" is not found in {str(lib_path)}\\n\"\n                f\"\\tspecifically, the model_lib for {model_lib_path}.\\n\"\n                f\"\\tCurrently available model_libs in {str(lib_path)}: {available_model_libs}\\n\"\n                f\"\\tThis can happen when model_lib_path_for_prepare_libs is manually specified\"\n                f\" in {str(package_config_path)}\\n\"\n                f\"\\tConsider removing 
model_lib_path_for_prepare_libs (so it can be jitted) \"\n                \"or check the compile command\"\n            )\n            logger.info(log_msg)\n            error_happened = True\n\n    if not error_happened:\n        logger.info(style.green(\"Validation passed\"))\n    else:\n        logger.info(style.red(\"Validation failed\"))\n        sys.exit(255)\n\n\ndef build_android_binding(mlc_llm_source_dir: Path, output: Path) -> None:\n    \"\"\"Build android binding in MLC LLM\"\"\"\n    mlc4j_path = mlc_llm_source_dir / \"android\" / \"mlc4j\"\n\n    # Move the model libraries to \"build/lib/\" for linking\n    os.makedirs(Path(\"build\") / \"lib\", exist_ok=True)\n    src_path = str(output / \"lib\" / \"libmodel_android.a\")\n    dst_path = str(Path(\"build\") / \"lib\" / \"libmodel_android.a\")\n    logger.info('Moving \"%s\" to \"%s\"', src_path, dst_path)\n    shutil.move(src_path, dst_path)\n\n    # Build mlc4j\n    logger.info(\"Building mlc4j\")\n    subprocess.run([sys.executable, mlc4j_path / \"prepare_libs.py\"], check=True, env=os.environ)\n    # Copy built files back to output directory.\n    lib_path = output / \"lib\" / \"mlc4j\"\n    os.makedirs(lib_path, exist_ok=True)\n    logger.info('Clean up all directories under \"%s\"', str(lib_path))\n    for content_path in lib_path.iterdir():\n        if content_path.is_dir():\n            shutil.rmtree(content_path)\n\n    src_path = str(mlc4j_path / \"src\")\n    dst_path = str(lib_path / \"src\")\n    logger.info('Copying \"%s\" to \"%s\"', src_path, dst_path)\n    shutil.copytree(src_path, dst_path)\n\n    src_path = str(mlc4j_path / \"build.gradle\")\n    dst_path = str(lib_path / \"build.gradle\")\n    logger.info('Copying \"%s\" to \"%s\"', src_path, dst_path)\n    shutil.copy(src_path, dst_path)\n\n    src_path = str(Path(\"build\") / \"output\")\n    dst_path = str(lib_path / \"output\")\n    logger.info('Copying \"%s\" to \"%s\"', src_path, dst_path)\n    shutil.copytree(src_path, 
dst_path)\n\n    os.makedirs(lib_path / \"src\" / \"main\" / \"assets\")\n    src_path = str(output / \"bundle\" / \"mlc-app-config.json\")\n    dst_path = str(lib_path / \"src\" / \"main\" / \"assets\" / \"mlc-app-config.json\")\n    logger.info('Moving \"%s\" to \"%s\"', src_path, dst_path)\n    shutil.move(src_path, dst_path)\n\n\ndef build_iphone_binding(mlc_llm_source_dir: Path, output: Path) -> None:\n    \"\"\"Build iOS binding in MLC LLM\"\"\"\n    # Build iphone binding\n    logger.info(\"Build iphone binding\")\n    subprocess.run(\n        [\"bash\", mlc_llm_source_dir / \"ios\" / \"prepare_libs.sh\"],\n        check=True,\n        env=os.environ,\n    )\n\n    # Copy built libraries back to output directory.\n    for static_library in (Path(\"build\") / \"lib\").iterdir():\n        dst_path = str(output / \"lib\" / static_library.name)\n        logger.info('Copying \"%s\" to \"%s\"', static_library, dst_path)\n        shutil.copy(static_library, dst_path)\n\n\ndef build_macabi_binding(mlc_llm_source_dir: Path, output: Path) -> None:\n    \"\"\"Build Mac Catalyst binding in MLC LLM\"\"\"\n    deployment_target = os.environ.get(\"MLC_MACABI_DEPLOYMENT_TARGET\", \"18.0\")\n    macabi_arch = os.environ.get(\"MLC_MACABI_ARCH\", \"\").strip() or \"arm64\"\n    logger.info(\"Build macabi binding (deployment target %s)\", deployment_target)\n    cmd = [\n        \"bash\",\n        str(mlc_llm_source_dir / \"ios\" / \"prepare_libs.sh\"),\n        \"--catalyst\",\n        \"--deployment-target\",\n        deployment_target,\n    ]\n    if macabi_arch:\n        cmd += [\"--arch\", macabi_arch]\n    subprocess.run(cmd, check=True, env=os.environ)\n\n    # Copy built libraries back to output directory.\n    build_dir = Path(f\"build-maccatalyst-{macabi_arch}\")\n    for static_library in (build_dir / \"lib\").iterdir():\n        dst_path = str(output / \"lib\" / static_library.name)\n        logger.info('Copying \"%s\" to \"%s\"', static_library, dst_path)\n        
shutil.copy(static_library, dst_path)\n\n\ndef package(\n    package_config_path: Path,\n    mlc_llm_source_dir: Path,\n    output: Path,\n) -> None:\n    \"\"\"Python entrypoint of package.\"\"\"\n    logger.info('MLC LLM HOME: \"%s\"', mlc_llm_source_dir)\n\n    # - Read package config.\n    with open(package_config_path, \"r\", encoding=\"utf-8\") as file:\n        package_config = json.load(file)\n    if not isinstance(package_config, dict):\n        raise ValueError(\n            \"The content of MLC package config is expected to be a dict with \"\n            f'field \"model_list\". However, the content of \"{package_config_path}\" is not a dict.'\n        )\n\n    # - Read device.\n    if \"device\" not in package_config:\n        raise ValueError(f'JSON file \"{package_config_path}\" is required to have field \"device\".')\n    device = package_config[\"device\"]\n    if device not in SUPPORTED_DEVICES:\n        raise ValueError(\n            f'The \"device\" field of JSON file {package_config_path} is expected to be one of '\n            f'{SUPPORTED_DEVICES}, while \"{device}\" is given in the JSON.'\n        )\n\n    bundle_dir = output / \"bundle\"\n    app_config_path = bundle_dir / \"mlc-app-config.json\"\n    # - Build model libraries.\n    model_lib_path_for_prepare_libs = build_model_library(\n        package_config, device, bundle_dir, app_config_path\n    )\n    # - Validate model libraries.\n    validate_model_lib(\n        app_config_path,\n        package_config_path,\n        model_lib_path_for_prepare_libs,\n        device,\n        output,\n    )\n\n    # - Copy model libraries\n    if device == \"android\":\n        build_android_binding(mlc_llm_source_dir, output)\n    elif device == \"iphone\":\n        build_iphone_binding(mlc_llm_source_dir, output)\n    elif device == \"macabi\":\n        build_macabi_binding(mlc_llm_source_dir, output)\n    else:\n        assert False, \"Cannot reach here\"\n\n    logger.info(\"All finished.\")\n"
  },
  {
    "path": "python/mlc_llm/interface/router.py",
    "content": "\"\"\"Python entrypoint of router.\"\"\"\n\n# pylint: disable=fixme\nfrom http import HTTPStatus\nfrom typing import AsyncGenerator, List, Literal, Optional, Type\n\nimport fastapi\nimport uvicorn\nfrom fastapi.middleware.cors import CORSMiddleware\n\nfrom mlc_llm.protocol import error_protocol\nfrom mlc_llm.protocol.openai_api_protocol import CompletionLogProbs, CompletionRequest\nfrom mlc_llm.router import Router\nfrom mlc_llm.serve import engine_base, engine_utils\n\n\ndef serve(\n    model: str,\n    model_lib: Optional[str],\n    router_host: str,\n    router_port: int,\n    endpoint_hosts: List[str],\n    endpoint_ports: List[int],\n    endpoint_num_gpus: List[int],\n    enable_prefix_cache: bool,\n    router_mode: Literal[\"disagg\", \"round-robin\"] = \"round-robin\",\n    pd_balance_factor: float = 0.0,\n    router_type: Type[Router] = Router,\n):  # pylint: disable=too-many-arguments\n    \"\"\"Start the router with the specified configuration.\"\"\"\n    # 1. Instantiate router\n    router = router_type(\n        model=model,\n        model_lib=model_lib,\n        hosts=endpoint_hosts,\n        ports=endpoint_ports,\n        num_gpus=endpoint_num_gpus,\n        enable_prefix_cache=enable_prefix_cache,\n        router_mode=router_mode,\n        pd_balance_factor=pd_balance_factor,\n    )\n\n    router_app = fastapi.APIRouter()\n\n    @router_app.post(\"/v1/completions\")\n    async def request_completion(request: CompletionRequest, raw_request: fastapi.Request):\n        \"\"\"OpenAI-compatible completion API.\n        API reference: https://platform.openai.com/docs/api-reference/completions/create\n        \"\"\"\n        if router is None:\n            return error_protocol.create_error_response(\n                HTTPStatus.BAD_REQUEST, message=\"Router is not initialized.\"\n            )\n        request_id = f\"cmpl-{engine_utils.random_uuid()}\"\n\n        # Streaming response.\n        if request.stream:\n            # We manually 
get the first response from the generator to\n            # capture potential exceptions in this scope, rather than\n            # the StreamingResponse scope.\n            stream_generator = router.handle_completion(  # pylint: disable=protected-access\n                request, request_id\n            )\n            first_response = await anext(  # type: ignore  # pylint: disable=undefined-variable\n                stream_generator\n            )\n\n            async def completion_stream_generator() -> AsyncGenerator[str, None]:\n                if isinstance(first_response, StopAsyncIteration):\n                    yield \"data: [DONE]\\n\\n\"\n                    return\n                yield f\"data: {first_response.model_dump_json(by_alias=True)}\\n\\n\"\n                async for response in stream_generator:\n                    yield f\"data: {response.model_dump_json(by_alias=True)}\\n\\n\"\n                yield \"data: [DONE]\\n\\n\"\n\n            return fastapi.responses.StreamingResponse(\n                completion_stream_generator(), media_type=\"text/event-stream\"\n            )\n\n        # FIXME: Non-streaming response not fully implemented\n        request_final_usage = None\n        output_texts = [\"\"] * request.n\n        finish_reasons: List[Optional[str]] = [None] * request.n\n        logprob_results: List[Optional[CompletionLogProbs]] = [None] * request.n\n\n        async for response in router.handle_completion(  # pylint: disable=protected-access\n            request, request_id\n        ):\n            if await raw_request.is_disconnected():\n                # In non-streaming cases, the engine will not be notified\n                # when the request is disconnected.\n                # Therefore, we check if it is disconnected each time,\n                # and explicitly return.\n                # Note that request abort is triggered when the async for and function scope ends.\n                return 
error_protocol.create_error_response(\n                    HTTPStatus.BAD_REQUEST, message=\"The request has disconnected\"\n                )\n            # TODO(Charlie): This is copied from engine.py --\n            # why is it here? Non-streaming only has a single chunk, right?\n            # this is the final chunk\n            # if response.usage is not None:\n            #     request_final_usage = response.usage\n            #     continue\n            for choice in response.choices:\n                output_texts[choice.index] += choice.text\n                if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                    finish_reasons[choice.index] = choice.finish_reason\n                if choice.logprobs is not None:\n                    logprob_results[choice.index] = choice.logprobs\n\n        assert all(finish_reason is not None for finish_reason in finish_reasons)\n        return engine_base.wrap_completion_response(\n            request_id=request_id,\n            model=request.model,\n            output_texts=output_texts,\n            finish_reasons=finish_reasons,\n            logprob_results=logprob_results,\n            usage=request_final_usage,\n        )\n\n    # 2. Set up app\n    app = fastapi.FastAPI()\n    app.add_middleware(CORSMiddleware)\n    app.include_router(router_app)\n    app.exception_handler(error_protocol.BadRequestError)(error_protocol.bad_request_error_handler)\n\n    # 3. Run\n    uvicorn.run(app, host=router_host, port=router_port, log_level=\"info\")\n"
  },
  {
    "path": "python/mlc_llm/interface/serve.py",
    "content": "\"\"\"Python entrypoint of serve.\"\"\"\n\nfrom typing import Any, List, Literal, Optional, Tuple, Union\n\nimport fastapi\nimport uvicorn\nfrom fastapi.middleware.cors import CORSMiddleware\n\nfrom mlc_llm.protocol import error_protocol\nfrom mlc_llm.serve import engine\nfrom mlc_llm.serve.embedding_engine import AsyncEmbeddingEngine\nfrom mlc_llm.serve.entrypoints import (\n    debug_entrypoints,\n    metrics_entrypoints,\n    microserving_entrypoints,\n    openai_entrypoints,\n)\nfrom mlc_llm.serve.server import ServerContext\nfrom mlc_llm.support import logging\n\nlogger = logging.getLogger(__name__)\n\n\ndef serve(\n    model: str,\n    device: str,\n    model_lib: Optional[str],\n    mode: Literal[\"local\", \"interactive\", \"server\"],\n    enable_debug: bool,\n    additional_models: List[Union[str, Tuple[str, str]]],\n    embedding_model: Optional[str],\n    embedding_model_lib: Optional[str],\n    tensor_parallel_shards: Optional[int],\n    pipeline_parallel_stages: Optional[int],\n    opt: Optional[str],\n    max_num_sequence: Optional[int],\n    max_total_sequence_length: Optional[int],\n    max_single_sequence_length: Optional[int],\n    prefill_chunk_size: Optional[int],\n    sliding_window_size: Optional[int],\n    attention_sink_size: Optional[int],\n    max_history_size: Optional[int],\n    gpu_memory_utilization: Optional[float],\n    speculative_mode: Literal[\"disable\", \"small_draft\", \"eagle\", \"medusa\"],\n    spec_draft_length: Optional[int],\n    spec_tree_width: Optional[int],\n    prefix_cache_mode: Literal[\"disable\", \"radix\"],\n    prefix_cache_max_num_recycling_seqs: Optional[int],\n    prefill_mode: Literal[\"hybrid\", \"chunked\"],\n    enable_tracing: bool,\n    host: str,\n    port: int,\n    allow_credentials: bool,\n    allow_origins: Any,\n    allow_methods: Any,\n    allow_headers: Any,\n    api_key: Optional[str] = None,\n):  # pylint: disable=too-many-arguments, too-many-locals\n    \"\"\"Serve the model 
with the specified configuration.\"\"\"\n    # Create engine and start the background loop\n    async_engine = engine.AsyncMLCEngine(\n        model=model,\n        device=device,\n        model_lib=model_lib,\n        mode=mode,\n        engine_config=engine.EngineConfig(\n            additional_models=additional_models,\n            tensor_parallel_shards=tensor_parallel_shards,\n            pipeline_parallel_stages=pipeline_parallel_stages,\n            opt=opt,\n            max_num_sequence=max_num_sequence,\n            max_total_sequence_length=max_total_sequence_length,\n            max_single_sequence_length=max_single_sequence_length,\n            prefill_chunk_size=prefill_chunk_size,\n            sliding_window_size=sliding_window_size,\n            attention_sink_size=attention_sink_size,\n            max_history_size=max_history_size,\n            gpu_memory_utilization=gpu_memory_utilization,\n            speculative_mode=speculative_mode,\n            spec_draft_length=spec_draft_length,\n            spec_tree_width=spec_tree_width,\n            prefix_cache_mode=prefix_cache_mode,\n            prefix_cache_max_num_recycling_seqs=prefix_cache_max_num_recycling_seqs,\n            prefill_mode=prefill_mode,\n        ),\n        enable_tracing=enable_tracing,\n    )\n\n    # Set up embedding model if specified\n    emb_engine = None\n    if embedding_model is not None:\n        if embedding_model_lib is None:\n            raise ValueError(\n                \"--embedding-model-lib is required when --embedding-model is specified.\"\n            )\n        emb_engine = AsyncEmbeddingEngine(\n            model=embedding_model,\n            model_lib=embedding_model_lib,\n            device=device,\n        )\n        logger.info(\"Embedding model %s loaded successfully.\", embedding_model)\n\n    with ServerContext() as server_context:\n        server_context.add_model(model, async_engine)\n        if emb_engine is not None:\n            
server_context.add_embedding_engine(embedding_model, emb_engine)\n        server_context.api_key = api_key\n\n        app = fastapi.FastAPI()\n        app.add_middleware(\n            CORSMiddleware,\n            allow_credentials=allow_credentials,\n            allow_origins=allow_origins,\n            allow_methods=allow_methods,\n            allow_headers=allow_headers,\n        )\n\n        app.include_router(openai_entrypoints.app)\n        app.include_router(metrics_entrypoints.app)\n        app.include_router(microserving_entrypoints.app)\n\n        server_context.enable_debug = enable_debug\n\n        if enable_debug:\n            app.include_router(debug_entrypoints.app)\n            logger.info(\"Enable debug endpoint and debug_config in requests...\")\n\n        app.exception_handler(error_protocol.BadRequestError)(\n            error_protocol.bad_request_error_handler\n        )\n        uvicorn.run(app, host=host, port=port, log_level=\"info\")\n"
  },
  {
    "path": "python/mlc_llm/json_ffi/__init__.py",
    "content": "\"\"\"JSON FFI is a pure string-based interface to the MLC LLM Engine.\n\nWe provide the JSON FFI interface for both testing purposes\nand internal use. For most Python API usage, please use MLCEngine\nand AsyncMLCEngine.\n\"\"\"\n\nfrom .engine import JSONFFIEngine\n"
  },
  {
    "path": "python/mlc_llm/json_ffi/engine.py",
    "content": "# pylint: disable=chained-comparison,missing-docstring,too-few-public-methods,too-many-instance-attributes\n# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable\nimport json\nimport queue\nimport threading\nfrom typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Union\n\nimport tvm\n\nfrom mlc_llm.protocol import debug_protocol, openai_api_protocol\nfrom mlc_llm.serve import engine_utils\nfrom mlc_llm.serve.engine_base import (\n    EngineConfig,\n    EngineMetrics,\n    _check_engine_config,\n    _parse_models,\n    _process_model_args,\n    _query_engine_metrics,\n    detect_device,\n)\nfrom mlc_llm.tokenizers import Tokenizer\n\n\nclass EngineState:\n    sync_queue: queue.Queue\n\n    def get_request_stream_callback(self) -> Callable[[str], None]:\n        # ChatCompletionStreamResponse\n\n        def _callback(chat_completion_stream_responses_json_str: str) -> None:\n            self._sync_request_stream_callback(chat_completion_stream_responses_json_str)\n\n        return _callback\n\n    def _sync_request_stream_callback(self, chat_completion_stream_responses_json_str: str) -> None:\n        # Put the delta outputs into the queue without blocking.\n        self.sync_queue.put_nowait(chat_completion_stream_responses_json_str)\n\n    def handle_chat_completion(\n        self, ffi: dict, request_json_str: str, include_usage: bool, request_id: str\n    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:\n        \"\"\"Helper method to handle chat completion\n\n        Note\n        ----\n        ffi is explicitly passed in to avoid a cyclic dependency,\n        as ffi will capture EngineState\n        \"\"\"\n        self.sync_queue = queue.Queue()\n\n        success = bool(ffi[\"chat_completion\"](request_json_str, request_id))\n\n        try:\n            last_chunk_arrived = False\n            while not last_chunk_arrived:\n                chat_completion_responses_json_str = 
self.sync_queue.get()\n                chat_completion_responses_list = json.loads(chat_completion_responses_json_str)\n                for chat_completion_response_json_dict in chat_completion_responses_list:\n                    chat_completion_response = (\n                        openai_api_protocol.ChatCompletionStreamResponse.model_validate(\n                            chat_completion_response_json_dict\n                        )\n                    )\n                    # the chunk with usage is always the last chunk\n                    if chat_completion_response.usage is not None:\n                        if include_usage:\n                            yield chat_completion_response\n                        last_chunk_arrived = True\n                        break\n                    yield chat_completion_response\n        except Exception as exception:  # pylint: disable=broad-exception-caught\n            ffi[\"abort\"](request_id)\n            raise exception\n\n\nclass BackgroundLoops:\n    \"\"\"Helper class to keep track of background loops\"\"\"\n\n    def __init__(self, ffi: dict):\n        self._ffi = ffi\n        # important: avoid self reference in closure\n        background_loop = self._ffi[\"run_background_loop\"]\n        background_stream_back_loop = self._ffi[\"run_background_stream_back_loop\"]\n\n        # Create the background engine-driving thread and start the loop.\n        self._background_loop_thread: threading.Thread = threading.Thread(target=background_loop)\n        self._background_stream_back_loop_thread: threading.Thread = threading.Thread(\n            target=background_stream_back_loop\n        )\n        self._background_loop_thread.start()\n        self._background_stream_back_loop_thread.start()\n        self._terminated = False\n\n    def __del__(self):\n        self.terminate()\n\n    def terminate(self):\n        if self._terminated:\n            return\n        self._terminated = True\n        
self._ffi[\"exit_background_loop\"]()\n        self._background_loop_thread.join()\n        self._background_stream_back_loop_thread.join()\n\n\nclass Completions:\n    \"\"\"Completions class to be compatible with OpenAI API\"\"\"\n\n    _ffi: dict\n    _state: EngineState\n    _background_loops: BackgroundLoops\n\n    def __init__(self, ffi: dict, state: EngineState, background_loops: BackgroundLoops):\n        self._ffi = ffi\n        self._state = state\n        self._background_loops = background_loops\n\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = True,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:\n        if request_id is None:\n            request_id = f\"chatcmpl-{engine_utils.random_uuid()}\"\n        debug_config = extra_body.get(\"debug_config\", None) if extra_body is not None else None\n        if not stream:\n            raise ValueError(\"JSONFFIEngine only supports stream=True\")\n        request = openai_api_protocol.ChatCompletionRequest(\n    
        messages=[\n                openai_api_protocol.ChatCompletionMessage.model_validate(message)\n                for message in messages\n            ],\n            model=model,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            logprobs=logprobs,\n            top_logprobs=top_logprobs,\n            logit_bias=logit_bias,\n            max_tokens=max_tokens,\n            n=n,\n            seed=seed,\n            stop=stop,\n            stream=stream,\n            stream_options=(\n                openai_api_protocol.StreamOptions.model_validate(stream_options)\n                if stream_options is not None\n                else None\n            ),\n            temperature=temperature,\n            top_p=top_p,\n            tools=(\n                [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]\n                if tools is not None\n                else None\n            ),\n            tool_choice=tool_choice,\n            user=user,\n            response_format=(\n                openai_api_protocol.RequestResponseFormat.model_validate(response_format)\n                if response_format is not None\n                else None\n            ),\n            debug_config=(\n                debug_protocol.DebugConfig.model_validate(debug_config)\n                if debug_config is not None\n                else None\n            ),\n        )\n        chatcmpl_generator = self._state.handle_chat_completion(\n            self._ffi,\n            request.model_dump_json(by_alias=True),\n            include_usage=(\n                request.stream_options is not None and request.stream_options.include_usage\n            ),\n            request_id=request_id,\n        )\n        for response in chatcmpl_generator:  # pylint: disable=use-yield-from\n            yield response\n\n\nclass Chat:\n    \"\"\"Chat class to be compatible with OpenAI API\"\"\"\n\n    completions: Completions\n\n 
   def __init__(self, ffi: dict, state: EngineState, background_loops: BackgroundLoops):\n        self.completions = Completions(ffi, state, background_loops)\n\n\nclass JSONFFIEngine:\n    chat: Chat\n\n    def __init__(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        model: str,\n        device: Union[str, tvm.runtime.Device] = \"auto\",\n        *,\n        model_lib: Optional[str] = None,\n        mode: Literal[\"local\", \"interactive\", \"server\"] = \"local\",\n        engine_config: Optional[EngineConfig] = None,\n    ) -> None:\n        # - Check the fields of `engine_config`.\n        if engine_config is None:\n            engine_config = EngineConfig()\n        _check_engine_config(model, model_lib, mode, engine_config)\n\n        # - Initialize model loading info.\n        models = _parse_models(model, model_lib, engine_config.additional_models)\n        if isinstance(device, str):\n            device = detect_device(device)\n        assert isinstance(device, tvm.runtime.Device)\n        model_args = _process_model_args(models, device, engine_config)[0]\n\n        # - Record the resolved model lib for each model.\n        for i, model_info in enumerate(models):\n            model_info.model_lib = model_args[i][1]\n\n        # - Initialize engine state and engine.\n        self._state = EngineState()\n        module = tvm.get_global_func(\"mlc.json_ffi.CreateJSONFFIEngine\", allow_missing=False)()\n        self._ffi = {\n            key: module[key]\n            for key in [\n                \"init_background_engine\",\n                \"reload\",\n                \"unload\",\n                \"reset\",\n                \"chat_completion\",\n                \"abort\",\n                \"run_background_loop\",\n                \"run_background_stream_back_loop\",\n                \"exit_background_loop\",\n            ]\n        }\n        self.tokenizer = Tokenizer(model_args[0][0])\n        self._background_loops = 
BackgroundLoops(self._ffi)\n\n        engine_config.model = model_args[0][0]\n        engine_config.model_lib = model_args[0][1]\n        engine_config.additional_models = model_args[1:]  # type: ignore\n        engine_config.mode = mode\n        self.engine_config = engine_config\n\n        self._ffi[\"init_background_engine\"](\n            device.dlpack_device_type(),\n            device.index,\n            self._state.get_request_stream_callback(),\n        )\n        self._ffi[\"reload\"](self.engine_config.asjson())\n\n        self.chat = Chat(self._ffi, self._state, self._background_loops)\n\n    def metrics(self) -> EngineMetrics:\n        \"\"\"Get the engine metrics.\"\"\"\n        return _query_engine_metrics(self)\n\n    def _raw_chat_completion(\n        self, request_json_str: str, include_usage: bool, request_id: str\n    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:\n        \"\"\"Raw chat completion API\"\"\"\n        return self._state.handle_chat_completion(\n            self._ffi, request_json_str, include_usage, request_id\n        )\n\n    def terminate(self):\n        \"\"\"Explicitly terminate the engine\"\"\"\n        self._background_loops.terminate()\n\n    def _test_reload(self):\n        self._ffi[\"reload\"](self.engine_config.asjson())\n\n    def _test_reset(self):\n        self._ffi[\"reset\"]()\n\n    def _test_unload(self):\n        self._ffi[\"unload\"]()\n"
  },
  {
    "path": "python/mlc_llm/libinfo.py",
    "content": "\"\"\"Library information. This is a standalone file that can be used to get various info\"\"\"\n\n#! pylint: disable=protected-access\nimport os\nimport sys\n\n__version__ = \"0.1.dev0\"\nMLC_LIBRARY_PATH = os.environ.get(\"MLC_LIBRARY_PATH\", None)\n\n\ndef get_env_paths(env_var, splitter):\n    \"\"\"Get the list of paths stored in an environment variable.\"\"\"\n    if os.environ.get(env_var, None):\n        return [p.strip() for p in os.environ[env_var].split(splitter)]\n    return []\n\n\ndef get_dll_directories():\n    \"\"\"Get the candidate directories that may contain MLC LLM shared libraries.\"\"\"\n    curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__)))\n    source_dir = os.path.abspath(os.path.join(curr_dir, \"..\", \"..\"))\n    dll_path = [\n        curr_dir,\n        os.path.join(source_dir, \"build\"),\n        os.path.join(source_dir, \"build\", \"Release\"),\n    ]\n    if MLC_LIBRARY_PATH:\n        dll_path.append(MLC_LIBRARY_PATH)\n    if \"CONDA_PREFIX\" in os.environ:\n        dll_path.append(os.path.join(os.environ[\"CONDA_PREFIX\"], \"lib\"))\n    if sys.platform.startswith(\"linux\") or sys.platform.startswith(\"freebsd\"):\n        dll_path.extend(get_env_paths(\"LD_LIBRARY_PATH\", \":\"))\n    elif sys.platform.startswith(\"darwin\"):\n        dll_path.extend(get_env_paths(\"DYLD_LIBRARY_PATH\", \":\"))\n    elif sys.platform.startswith(\"win32\"):\n        dll_path.extend(get_env_paths(\"PATH\", \";\"))\n    return [os.path.abspath(p) for p in dll_path if os.path.isdir(p)]\n\n\ndef find_lib_path(name, optional=False):\n    \"\"\"Find the MLC LLM shared library with the given name.\n\n    Parameters\n    ----------\n    name : str\n        The name of the library, without the platform-specific prefix or suffix.\n\n    optional : bool\n        Whether the library is optional. 
If False, a RuntimeError is raised when\n        the library cannot be found.\n\n    Returns\n    -------\n    lib_found : List[str]\n        Paths of the matching library files that exist on disk.\n    \"\"\"\n    if sys.platform.startswith(\"linux\") or sys.platform.startswith(\"freebsd\"):\n        lib_name = f\"lib{name}.so\"\n    elif sys.platform.startswith(\"win32\"):\n        lib_name = f\"{name}.dll\"\n    elif sys.platform.startswith(\"darwin\"):\n        
lib_name = f\"lib{name}.dylib\"\n    else:\n        lib_name = f\"lib{name}.so\"\n\n    dll_paths = get_dll_directories()\n    lib_dll_path = [os.path.join(p, lib_name) for p in dll_paths]\n    lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]\n    if not lib_found:\n        if not optional:\n            message = (\n                f\"Cannot find libraries: {lib_name}\\n\"\n                + \"List of candidates:\\n\"\n                + \"\\n\".join(lib_dll_path)\n            )\n            raise RuntimeError(message)\n    return lib_found\n"
  },
  {
    "path": "python/mlc_llm/loader/__init__.py",
    "content": "\"\"\"\nA subpackage of the compiler that represents mapping between external parameters, quantized\nparameters and parameters in MLC-defined models.\n\"\"\"\n\nfrom .huggingface_loader import HuggingFaceLoader\nfrom .loader import LOADER, Loader\nfrom .mapping import ExternMapping, QuantizeMapping\n"
  },
  {
    "path": "python/mlc_llm/loader/huggingface_loader.py",
    "content": "\"\"\"A weight loader for HuggingFace's PyTorch/SafeTensor format\"\"\"\n\nimport gc\nimport json\nfrom collections import OrderedDict, defaultdict\nfrom pathlib import Path\nfrom typing import Callable, Dict, Iterator, List, Optional, Tuple\n\nimport numpy as np\nfrom tqdm import tqdm\nfrom tvm.runtime import Device, Tensor\nfrom tvm.runtime import tensor as as_tensor\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.preshard import _sharded_param_name\nfrom mlc_llm.support.style import bold\n\nfrom .mapping import ExternMapping, QuantizeMapping\nfrom .stats import Stats\nfrom .utils import check_parameter_usage, load_safetensor_shard, load_torch_shard\n\nlogger = logging.getLogger(__name__)\n\n\nclass HuggingFaceLoader:  # pylint: disable=too-few-public-methods\n    \"\"\"A loader that reads weights in HuggingFace's PyTorch/SafeTensor format and converts\n    them to MLC's parameters.\n\n    Attributes\n    ----------\n    stats : Stats\n        Statistics of the loading process.\n\n    extern_param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch/SafeTensor.\n\n    torch_to_path : Dict[str, Path]\n        A mapping from PyTorch/SafeTensor parameter name to the path of the file containing it.\n        When all parameters are stored in a single file, every name maps to that file.\n\n    cached_files : Dict[Path, Dict[str, np.ndarray]]\n        A cache of the loaded files. 
The key is the path of the file, and the value is a mapping\n        from parameter name to the parameter value.\n\n    quantize_param_map : Optional[QuantizeMapping]\n        The quantization mapping from MLC to quantized MLC parameters.\n    \"\"\"\n\n    stats: Stats\n    cached_files: Dict[Path, Dict[str, np.ndarray]]\n    torch_to_path: Dict[str, Path]\n    extern_param_map: ExternMapping\n    quantize_param_map: Optional[QuantizeMapping]\n\n    def __init__(\n        self,\n        path: Path,\n        extern_param_map: ExternMapping,\n        quantize_param_map: Optional[QuantizeMapping] = None,\n    ) -> None:\n        \"\"\"Create a parameter loader from HuggingFace PyTorch/SafeTensor format.\n\n        Parameters\n        ----------\n        path : pathlib.Path\n            Path to a JSON index file, a PyTorch bin file, or a SafeTensor file.\n            1) A JSON index file is usually `pytorch_model.bin.index.json`\n            or `model.safetensors.index.json` in the repo, which contains a `weight_map` that\n            maps each PyTorch parameter to the file containing the weight.\n            2) A PyTorch bin file is usually `pytorch_model.bin` in the repo,\n            which contains all the parameters.\n            3) A SafeTensor file is usually `model.safetensors` in the repo,\n            which contains all the parameters.\n\n        extern_param_map : ExternMapping\n            Maps an MLC parameter to a list of PyTorch/SafeTensor parameters.\n\n        quantize_param_map : Optional[QuantizeMapping]\n            The quantization mapping from MLC to quantized MLC parameters. 
Defaults to None,\n            which means no quantization.\n        \"\"\"\n        assert path.is_file(), f\"Path {path} is not a file\"\n        self.stats = Stats()\n        self.extern_param_map = extern_param_map\n        self.cached_files = {}\n        self.torch_to_path = {}\n        self.quantize_param_map = quantize_param_map\n        if path.suffix in (\".bin\", 
\".safetensors\", \".pt\"):\n            self._load_file(path)\n            for name in self.cached_files[path].keys():\n                self.torch_to_path[name] = path\n        elif path.suffix == \".json\":\n            with path.open(\"r\", encoding=\"utf-8\") as in_file:\n                torch_weight_map = json.load(in_file)[\"weight_map\"]\n            for torch_name, path_str in torch_weight_map.items():\n                self.torch_to_path[torch_name] = path.parent / path_str\n        else:\n            raise FileNotFoundError(f\"Unknown file suffix: {path}\")\n        check_parameter_usage(extern_param_map, set(self.torch_to_path.keys()))\n\n    def load(\n        self, device: Device, preshard_funcs: Optional[Dict[str, Callable]] = None\n    ) -> Iterator[Tuple[str, Tensor]]:\n        \"\"\"Load the parameters and yield each MLC parameter name with its value.\n\n        Parameters\n        ----------\n        device : Device\n            The device on which to store the parameters.\n\n        preshard_funcs : Optional[Dict[str, Callable]]\n            Functions keyed by parameter name, each splitting that parameter into shards.\n            Defaults to None, which means no presharding.\n\n        Yields\n        ------\n        Tuple[str, Tensor]\n            The MLC parameter name and its value, quantized if quantization mapping is provided.\n        \"\"\"\n        mlc_names = _loading_order(self.extern_param_map, self.torch_to_path)\n        for mlc_name in tqdm(mlc_names):\n            param = self._load_mlc_param(mlc_name, device=device)\n            # Apply quantization if needed; in this case the original parameter may become\n            # multiple quantized parameters.\n            for name, loader_param in self._load_or_quantize(mlc_name, param, device):\n                # Apply presharding if needed\n                if preshard_funcs is not None and name in preshard_funcs:\n                    for shard_id, shard_param in enumerate(preshard_funcs[name](loader_param)):\n                        yield _sharded_param_name(name, shard_id), shard_param\n                else:\n                    yield name, loader_param\n\n        
cached_files = list(self.cached_files.keys())\n        for path in cached_files:\n            self._unload_file(path)\n        self.stats.log_time_info(\"HF\")\n        self.stats.log_mem_usage()\n\n    def _load_mlc_param(self, mlc_name: str, device: Optional[Device]) -> Tensor:\n        torch_names = self.extern_param_map.param_map[mlc_name]\n        files_required = {self.torch_to_path[p] for p in torch_names}\n        files_existing = set(self.cached_files.keys())\n        files_to_load = files_required - files_existing\n        files_to_unload = files_existing - files_required\n\n        # Step 1. When some cached files are no longer required:\n        # - If nothing needs loading, defer unloading, since it would not lower peak memory usage;\n        # - If files need loading, unload immediately to free memory for the new files.\n        if files_to_load:\n            for path in files_to_unload:\n                self._unload_file(path)\n        # Step 2. Load all the files needed\n        for path in files_to_load:\n            self._load_file(path)\n        # Step 3. Collect all torch parameters in order\n        torch_params = [self.cached_files[self.torch_to_path[i]][i] for i in torch_names]\n        # Step 4. 
Apply the mapping function\n        with self.stats.timer(\"map_time_sec\"):\n            param = self.extern_param_map.map_func[mlc_name](*torch_params)\n        if device:\n            return as_tensor(param, device=device)\n        return as_tensor(param)\n\n    def _load_or_quantize(self, mlc_name, param, device: Device):\n        if self.quantize_param_map and mlc_name in self.quantize_param_map.param_map:\n            with self.stats.timer(\"quant_time_sec\"):\n                q_names = self.quantize_param_map.param_map[mlc_name]\n                q_params = self.quantize_param_map.map_func[mlc_name](param)\n                device.sync()\n            for q_name, q_param in zip(q_names, q_params):\n                logger.info(\n                    '[Quantized] Parameter: \"%s\", shape: %s, dtype: %s',\n                    bold(q_name),\n                    q_param.shape,\n                    q_param.dtype,\n                )\n                yield q_name, q_param\n        else:\n            logger.info(\n                '[Not quantized] Parameter: \"%s\", shape: %s, dtype: %s',\n                bold(mlc_name),\n                param.shape,\n                param.dtype,\n            )\n            device.sync()\n            yield mlc_name, param\n\n    def _load_file(self, path: Path) -> None:\n        logger.info(\"Loading HF parameters from: %s\", path)\n        load_func = load_safetensor_shard if path.suffix == \".safetensors\" else load_torch_shard\n        with self.stats.timer(\"load_time_sec\"):\n            result = {}\n            for name, param in load_func(path):\n                result[name] = param\n                self.stats.mem_add(param.nbytes)\n                if name not in self.extern_param_map.unused_params:\n                    self.stats.total_param_num += param.size\n            self.cached_files[path] = result\n\n    def _unload_file(self, path: Path) -> None:\n        logger.info(\"Unloading HF weight file: %s\", path)\n        with 
self.stats.timer(\"load_time_sec\"):\n            for _, param in self.cached_files[path].items():\n                self.stats.mem_rm(param.nbytes)\n            del self.cached_files[path]\n            gc.collect()\n\n\ndef _loading_order(param_map: ExternMapping, torch_to_path: Dict[str, Path]) -> List[str]:\n    # Step 1. Build a map from path to torch parameters\n    path_to_torch: Dict[Path, List[str]] = defaultdict(list)\n    for torch_name, path in torch_to_path.items():\n        path_to_torch[path].append(torch_name)\n    # Step 2. Build a map from torch parameters to MLC parameters\n    torch_to_mlc = defaultdict(list)\n    for mlc_name, torch_names in param_map.param_map.items():\n        for torch_name in torch_names:\n            torch_to_mlc[torch_name].append(mlc_name)\n    # Step 3. Construct the ordering that ensures file locality\n    order = OrderedDict()\n    for _, torch_names in path_to_torch.items():\n        for torch_name in torch_names:\n            for mlc_name in torch_to_mlc[torch_name]:\n                if mlc_name not in order:\n                    order[mlc_name] = 1\n    return list(order.keys())\n\n\n__all__ = [\"HuggingFaceLoader\"]\n"
  },
  {
    "path": "python/mlc_llm/loader/loader.py",
    "content": "\"\"\"A centralized registry of all existing loaders.\"\"\"\n\nfrom typing import Any, Dict\n\nfrom .huggingface_loader import HuggingFaceLoader\n\nLoader = Any\n\nLOADER: Dict[str, Any] = {\n    \"huggingface-torch\": HuggingFaceLoader,\n    \"huggingface-safetensor\": HuggingFaceLoader,\n    \"awq\": HuggingFaceLoader,\n}\n"
  },
  {
    "path": "python/mlc_llm/loader/mapping.py",
    "content": "\"\"\"Parameter mapping for converting different LLM implementations to MLC LLM.\"\"\"\n\nimport dataclasses\nfrom typing import Callable, Dict, List, Set, Union\n\nimport numpy as np\nfrom tvm.runtime import Tensor\n\nMapFuncVariadic = Union[\n    Callable[[], np.ndarray],\n    Callable[[np.ndarray], np.ndarray],\n    Callable[[np.ndarray, np.ndarray], np.ndarray],\n    Callable[[np.ndarray, np.ndarray, np.ndarray], np.ndarray],\n    Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray],\n]\n\n\n@dataclasses.dataclass\nclass ExternMapping:\n    \"\"\"Mapping from a parameter name in MLC LLM's model definition to its potential source,\n    for example, from MLC parameter \"model.layers.2.post_attention_layernorm.weight\" to the\n    corresponding PyTorch parameter.\n\n    Parameters\n    ----------\n    param_map : Dict[str, List[str]]\n        A dictionary that maps the name of a parameter to its sources. For example,\n        in Llama2, the sources of MLC parameter \"model.layers.0.self_attn.qkv_proj.weight\" in\n        HuggingFace PyTorch are:\n\n        - \"model.layers.0.self_attn.q_proj.weight\"\n        - \"model.layers.0.self_attn.k_proj.weight\"\n        - \"model.layers.0.self_attn.v_proj.weight\"\n\n    map_func : Dict[str, Callable[..., np.ndarray]]\n        A dictionary that maps the name of a parameter to a function that combines the source\n        parameters into the MLC parameter. 
For the example above, the function\n        would be: `lambda q, k, v: np.concatenate([q, k, v], axis=0)`.\n\n    unused_params : Set[str]\n        Parameter names in the source weights that are not used in the MLC LLM model definition.\n    \"\"\"\n\n    param_map: Dict[str, List[str]] = dataclasses.field(default_factory=dict)\n    map_func: Dict[str, MapFuncVariadic] = dataclasses.field(default_factory=dict)\n    unused_params: Set[str] = dataclasses.field(default_factory=set)\n\n    def add_mapping(\n        self,\n        map_from: str,\n        map_to: List[str],\n        func: MapFuncVariadic,\n    ) -> None:\n        \"\"\"Add a mapping from MLC parameters to source parameters as well as a mapping function.\"\"\"\n        self.param_map[map_from] = map_to\n        self.map_func[map_from] = func\n\n    def add_unused(self, name: str):\n        \"\"\"Add a parameter name in the source parameters to the set of unused parameters.\"\"\"\n        self.unused_params.add(name)\n\n\n@dataclasses.dataclass\nclass QuantizeMapping:\n    \"\"\"Mapping from a parameter in MLC LLM's model definition to its eventual names and values after\n    quantization. In certain group quantization, for example, `qkv_proj.weight` is mapped to\n    `qkv_proj.weight_quantized` and `qkv_proj.weight_scale` respectively. If a parameter's name is\n    not in the mapping, it is assumed to be unchanged, i.e. not quantized.\n\n    Parameters\n    ----------\n    param_map : Dict[str, List[str]]\n        A dictionary that maps the name of a parameter to its destination. 
For example,\n        in certain group quantization, the destinations of MLC parameter \"qkv_proj.weight\" are:\n\n        - \"qkv_proj.weight_quantized\"\n        - \"qkv_proj.weight_scale\"\n\n    map_func : Dict[str, Callable[[Tensor], List[Tensor]]]\n        A dictionary that maps the name of a parameter to a function that splits the MLC parameter\n        into the destination parameters.\n\n    Notes\n    -----\n    There are two forms of weight conversion in MLC LLM: A) on-the-fly quantization of the\n    raw fp16/bf16/fp32 weights from HuggingFace, and B) loading pre-quantized weights\n    from an external framework, e.g. AutoGPTQ or AutoAWQ. From the perspective of parameter\n    correspondence:\n\n    - In case A), it is recommended that the weight loader take both `ExternMapping` and\n    `QuantizeMapping` as input, and quantize each raw parameter on the fly as it is\n    loaded into RAM;\n    - In case B), a pass over `nn.Module` should first convert parameters\n    from their non-quantized form to the quantized one, and then only `ExternMapping` is\n    used to convert the quantized parameters into the desired form.\n    \"\"\"\n\n    param_map: Dict[str, List[str]]\n    map_func: Dict[str, Callable[[Tensor], List[Tensor]]]\n\n\n__all__ = [\"ExternMapping\", \"QuantizeMapping\"]\n"
  },
  {
    "path": "python/mlc_llm/loader/standard_loader.py",
    "content": "\"\"\"Standard HuggingFace loader mapping helpers.\"\"\"\n\nfrom __future__ import annotations\n\nimport functools\nfrom typing import Callable, Iterable, Optional, Sequence, Type\n\nimport numpy as np\nfrom tvm.relax.frontend import nn  # type: ignore[import]\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nNameTransform = Callable[[str], str]\nExportSpecGetter = Callable[[nn.Module], object]\n\n\ndef _default_export_spec(model: nn.Module) -> object:\n    return model.get_default_spec()\n\n\ndef make_standard_hf_loader(  # pylint: disable=too-many-arguments,too-many-locals\n    *,\n    model_cls: Type[nn.Module],\n    layer_prefix: str = \"model.layers\",\n    qkv_names: Sequence[str] = (\"q_proj\", \"k_proj\", \"v_proj\"),\n    qkv_concat_axis: int = 0,\n    qkv_target_name: str = \"qkv_proj\",\n    add_qkv_bias: bool = False,\n    qkv_bias_optional: bool = False,\n    gate_up_names: Sequence[str] = (\"gate_proj\", \"up_proj\"),\n    gate_up_concat_axis: int = 0,\n    gate_up_target_name: str = \"gate_up_proj\",\n    include_qkv: bool = True,\n    include_gate_up: bool = True,\n    add_unused: Optional[Iterable[str]] = None,\n    hf_prefix: str = \"model.\",\n    name_transform: Optional[NameTransform] = None,\n    export_spec_getter: Optional[ExportSpecGetter] = None,\n    num_layers_getter: Optional[Callable[[object], int]] = None,\n) -> Callable[[object, Quantization], ExternMapping]:\n    \"\"\"Create a standard loader for HuggingFace weights.\n\n    This handles the common QKV concatenation, gate+up concatenation, optional\n    QKV bias mapping, and passes through remaining parameters 1:1.\n    \"\"\"\n\n    if not qkv_names:\n        include_qkv = False\n    if not gate_up_names:\n        include_gate_up = False\n    if not include_qkv:\n        qkv_names = ()\n    if not include_gate_up:\n        gate_up_names = ()\n\n    def _default_name_transform(name: str) -> str:\n        # When hf_prefix is 
empty, strip the \"model.\" prefix so models that\n        # expose bare top-level weights (no \"model.\" namespace) still load.\n        if hf_prefix == \"\":\n            return name[6:] if name.startswith(\"model.\") else name\n        return name\n\n    name_transform_fn = name_transform or _default_name_transform\n    spec_getter = export_spec_getter or _default_export_spec\n    unused_names = tuple(add_unused or ())\n\n    def huggingface(  # pylint: disable=too-many-locals,too-many-branches\n        model_config: object,\n        quantization: Quantization,\n    ) -> ExternMapping:\n        model = model_cls(model_config)\n        if quantization is not None:\n            model.to(quantization.model_dtype)\n        _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n            spec=spec_getter(model),\n            allow_extern=True,\n        )\n        named_parameters = dict(_named_params)\n        mapping = ExternMapping()\n\n        if include_qkv or include_gate_up or unused_names:\n            if num_layers_getter is None:\n                num_layers = model_config.num_hidden_layers  # type: ignore[attr-defined]\n            else:\n                num_layers = num_layers_getter(model_config)\n\n            for i in range(num_layers):\n                attn = f\"{layer_prefix}.{i}.self_attn\"\n                if include_qkv:\n                    mlc_qkv_name = f\"{attn}.{qkv_target_name}.weight\"\n                    mlc_param = named_parameters[mlc_qkv_name]\n                    mapping.add_mapping(\n                        mlc_qkv_name,\n                        [name_transform_fn(f\"{attn}.{name}.weight\") for name in qkv_names],\n                        functools.partial(\n                            lambda q, k, v, dtype: np.concatenate(\n                                [q, k, v], axis=qkv_concat_axis\n                            ).astype(dtype),\n                            dtype=mlc_param.dtype,\n                        ),\n              
      )\n\n                    if add_qkv_bias:\n                        mlc_bias_name = f\"{attn}.{qkv_target_name}.bias\"\n                        if (not qkv_bias_optional) or mlc_bias_name in named_parameters:\n                            mlc_param = named_parameters[mlc_bias_name]\n                            mapping.add_mapping(\n                                mlc_bias_name,\n                                [name_transform_fn(f\"{attn}.{name}.bias\") for name in qkv_names],\n                                functools.partial(\n                                    lambda q, k, v, dtype: np.concatenate(\n                                        [q, k, v], axis=qkv_concat_axis\n                                    ).astype(dtype),\n                                    dtype=mlc_param.dtype,\n                                ),\n                            )\n\n                if include_gate_up:\n                    mlp = f\"{layer_prefix}.{i}.mlp\"\n                    mlc_gate_up_name = f\"{mlp}.{gate_up_target_name}.weight\"\n                    if gate_up_names:\n                        mlc_param = named_parameters[mlc_gate_up_name]\n                        mapping.add_mapping(\n                            mlc_gate_up_name,\n                            [name_transform_fn(f\"{mlp}.{name}.weight\") for name in gate_up_names],\n                            functools.partial(\n                                lambda gate, up, dtype: np.concatenate(\n                                    [gate, up], axis=gate_up_concat_axis\n                                ).astype(dtype),\n                                dtype=mlc_param.dtype,\n                            ),\n                        )\n\n                for unused_name in unused_names:\n                    mapping.add_unused(name_transform_fn(f\"{attn}.{unused_name}\"))\n\n        for mlc_name, mlc_param in named_parameters.items():\n            if mlc_name not in mapping.param_map:\n                
mapping.add_mapping(\n                    mlc_name,\n                    [name_transform_fn(mlc_name)],\n                    functools.partial(\n                        lambda x, dtype: x.astype(dtype),\n                        dtype=mlc_param.dtype,\n                    ),\n                )\n\n        return mapping\n\n    return huggingface\n"
  },
  {
    "path": "python/mlc_llm/loader/stats.py",
    "content": "\"\"\"Statistics of the loading process of parameter loaders\"\"\"\n\nimport dataclasses\nimport time\nfrom contextlib import contextmanager\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.style import green\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Stats:\n    \"\"\"Statistics of the loading process of parameter loaders.\n\n    Attributes\n    ----------\n    load_time_sec : float\n        Time used in loading the parameters.\n\n    map_time_sec : float\n        Time used in applying the mapping function, i.e. `ExternMapping.map_func`.\n\n    quant_time_sec : float\n        Time used in quantizing the parameters, i.e. `QuantizeMapping.map_func`.\n\n    current_memory_gb : float\n        The current RAM usage in GB.\n\n    total_memory_gb : float\n        The total size of data loaded from disk in GB.\n\n    max_memory_gb : float\n        The maximum RAM usage in GB.\n\n    total_param_num: int\n        Total number of parameters (original non-MLC model weights), excluding unused params.\n    \"\"\"\n\n    load_time_sec: float = 0.0\n    map_time_sec: float = 0.0\n    quant_time_sec: float = 0.0\n\n    current_memory_gb: float = 0.0\n    total_memory_gb: float = 0.0\n    max_memory_gb: float = 0.0\n\n    total_param_num: int = 0\n\n    def timer(self, attr):\n        \"\"\"A context manager to time the scope and add the time to the attribute.\"\"\"\n\n        @contextmanager\n        def timed_scope():\n            start_time = time.time()\n            yield\n            elapsed_time = time.time() - start_time\n            setattr(self, attr, getattr(self, attr) + elapsed_time)\n\n        return timed_scope()\n\n    def mem_add(self, nbytes: int):\n        \"\"\"Increase the tracked memory usage by the given number of bytes.\"\"\"\n        mem_gb = float(nbytes) / float(1024**3)\n        self.current_memory_gb += mem_gb\n        self.total_memory_gb += mem_gb\n        self.max_memory_gb = max(self.max_memory_gb, 
self.current_memory_gb)\n\n    def mem_rm(self, nbytes: int):\n        \"\"\"Decrease the tracked memory usage by the given number of bytes.\"\"\"\n        mem_gb = float(nbytes) / float(1024**3)\n        self.current_memory_gb -= mem_gb\n\n    def log_time_info(self, weight_format: str):\n        \"\"\"Log the time used in loading, pre-quantization mapping and quantization.\"\"\"\n        logger.info(\n            \"%s: \"\n            \"%s loading: %.3f sec; \"\n            \"Pre-quantization mapping: %.3f sec; \"\n            \"Quantization: %.3f sec\",\n            green(\"Time usage\"),\n            weight_format,\n            self.load_time_sec,\n            self.map_time_sec,\n            self.quant_time_sec,\n        )\n\n    def log_mem_usage(self):\n        \"\"\"Log the memory usage information.\"\"\"\n        logger.info(\n            \"%s: Peak RAM: %.3f GB. Total bytes loaded from disk: %.3f GB\",\n            green(\"RAM usage\"),\n            self.max_memory_gb,\n            self.total_memory_gb,\n        )\n"
  },
  {
    "path": "python/mlc_llm/loader/utils.py",
    "content": "\"\"\"Common utilities for loading parameters\"\"\"\n\n# pylint: disable=too-few-public-methods\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Iterator, Set, Tuple\n\nimport numpy as np\n\nfrom mlc_llm.support import logging\n\nif TYPE_CHECKING:\n    from tvm.runtime import Tensor\n\n    from .mapping import ExternMapping\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef check_parameter_usage(param_map: \"ExternMapping\", extern_weights: Set[str]):\n    \"\"\"Check that all external parameters have been used and are stored in the weights file.\"\"\"\n    used_extern_names = set(sum(param_map.param_map.values(), []))\n    # Check 1. All extern parameters in the weight files are used unless explicitly specified\n    unused_extern_names = extern_weights - used_extern_names - param_map.unused_params\n    if unused_extern_names:\n        logger.warning(\n            \"Unused extern parameters: %s\",\n            \", \".join(sorted(unused_extern_names)),\n        )\n    # Check 2. 
All extern parameters required are stored in the weight files\n    nonexistent_extern_names = used_extern_names - extern_weights\n    if nonexistent_extern_names:\n        raise ValueError(\n            \"The following extern parameters do not exist in the weight files:\\n  \"\n            + \"\\n  \".join(sorted(nonexistent_extern_names)),\n        )\n\n\ndef load_torch_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:\n    \"\"\"Load and yield PyTorch format parameters.\"\"\"\n    import torch  # pylint: disable=import-outside-toplevel\n\n    for name, param in torch.load(path, map_location=torch.device(\"cpu\")).items():\n        if param is None:\n            logger.warning(\"Encountered None param, skipping it: %s\", name)\n            continue\n        param = param.detach().cpu()\n        dtype = str(param.dtype)\n        if dtype == \"torch.bfloat16\":\n            param = param.float()\n        param = param.numpy()\n        yield name, param\n\n\ndef load_safetensor_shard(path: Path) -> Iterator[Tuple[str, np.ndarray]]:\n    \"\"\"Load and yield SafeTensor format parameters.\"\"\"\n    import safetensors  # pylint: disable=import-outside-toplevel,import-error\n    import torch  # pylint: disable=import-outside-toplevel\n\n    with safetensors.safe_open(path, framework=\"pt\", device=\"cpu\") as in_file:\n        for name in in_file.keys():\n            param = in_file.get_tensor(name)\n            param = param.detach().cpu()\n            dtype = str(param.dtype)\n            if dtype == \"torch.bfloat16\":\n                import ml_dtypes  # pylint: disable=import-outside-toplevel\n\n                param = param.view(torch.float16).cpu().numpy().view(ml_dtypes.bfloat16)\n            elif dtype == \"torch.float8_e4m3fn\":\n                import ml_dtypes  # pylint: disable=import-outside-toplevel\n\n                param = param.view(torch.uint8).cpu().numpy().view(ml_dtypes.float8_e4m3fn)\n            else:\n                param = 
param.numpy()\n            yield name, param\n"
  },
  {
    "path": "python/mlc_llm/model/__init__.py",
    "content": "\"\"\"Model definition for the compiler.\"\"\"\n\nfrom .model import MODELS, Model\nfrom .model_preset import MODEL_PRESETS\n"
  },
  {
    "path": "python/mlc_llm/model/baichuan/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/baichuan/baichuan_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's BaichuanLM parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .baichuan_model import BaichuanForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=BaichuanForCausalLM,\n    include_qkv=False,\n)\n"
  },
  {
    "path": "python/mlc_llm/model/baichuan/baichuan_model.py",
    "content": "\"\"\"\nImplementation for BAICHUAN architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass BaichuanConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Baichuan model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    initializer_range: float\n    intermediate_size: int\n    rms_norm_eps: float\n    use_cache: bool\n    pad_token_id: int\n    bos_token_id: int\n    eos_token_id: int\n    tie_word_embeddings: bool\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    head_dim: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass BaichuanAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: BaichuanConfig):\n        self.hidden_size = config.hidden_size\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards\n   
     self.head_dim = config.head_dim\n        self.W_pack = nn.Linear(self.hidden_size, 3 * self.num_heads * self.head_dim, bias=False)\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.W_pack(hidden_states)\n        qkv = op.reshape(qkv, (b, s, 3 * h, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nclass BaichuanMLP(nn.Module):\n    def __init__(self, config: BaichuanConfig):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass BaichuanDecoderLayer(nn.Module):\n    def __init__(self, config: BaichuanConfig):\n        norm_eps = config.rms_norm_eps\n        self.self_attn = BaichuanAttention(config=config)\n        self.mlp = BaichuanMLP(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, 
bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_heads * hd\n            k = self.self_attn.num_heads * hd\n            v = self.self_attn.num_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.W_pack.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_gate_up\", segs=[i, i], dim=0),\n            )\n            _set(\n                self.mlp.down_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_down_proj\", dim=1),\n            )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass BaichuanModel(nn.Module):\n    def __init__(self, config: BaichuanConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n   
         [BaichuanDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass BaichuanForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: BaichuanConfig):\n        self.model = BaichuanModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = config.head_dim\n        self.rope_theta = 10000\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            
input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n   
     prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                
\"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/bert/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/bert/bert_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's BERT parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\nfrom typing import Literal\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .bert_model import BertConfig, BertModel\n\n\ndef huggingface(\n    model_config: BertConfig,\n    quantization: Quantization,\n    hf_prefix: Literal[\"\", \"bert.\"] = \"\",\n) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : BertConfig\n        The configuration of the BERT model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    hf_prefix : Literal[\"\", \"bert.\"]\n        Prefix used in HuggingFace weight names. Defaults to \"\" for standard\n        BERT models; use \"bert.\" for checkpoints whose weight names carry that prefix.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = BertModel(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    def to_hf(name: str) -> str:\n        return f\"{hf_prefix}{name}\" if hf_prefix else name\n\n    for i in range(model_config.num_hidden_layers):\n        attn = f\"encoder.layer.{i}.attention.self\"\n        mlc_name = f\"{attn}.qkv.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                to_hf(f\"{attn}.query.weight\"),\n                to_hf(f\"{attn}.key.weight\"),\n          
      to_hf(f\"{attn}.value.weight\"),\n            ],\n            functools.partial(\n                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        mlc_name = f\"{attn}.qkv.bias\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                to_hf(f\"{attn}.query.bias\"),\n                to_hf(f\"{attn}.key.bias\"),\n                to_hf(f\"{attn}.value.bias\"),\n            ],\n            functools.partial(\n                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [to_hf(mlc_name)],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    # Mark unused weights that exist in HF but not in MLC\n    if hf_prefix:\n        mapping.add_unused(f\"{hf_prefix}pooler.dense.weight\")\n        mapping.add_unused(f\"{hf_prefix}pooler.dense.bias\")\n\n    return mapping\n\n\ndef huggingface_bge(model_config: BertConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping for BGE models.\n\n    BGE weights have no prefix but include extra unused weights:\n    pooler.dense.weight, pooler.dense.bias, embeddings.position_ids\n    \"\"\"\n    mapping = huggingface(model_config, quantization, \"\")\n    mapping.add_unused(\"pooler.dense.weight\")\n    mapping.add_unused(\"pooler.dense.bias\")\n    mapping.add_unused(\"embeddings.position_ids\")\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/bert/bert_model.py",
    "content": "\"\"\"\nImplementation for BERT architecture.\n\"\"\"\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass BertConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the BERT model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    intermediate_size: int\n    hidden_act: str\n    layer_norm_eps: float\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    type_vocab_size: int = 2\n    pad_token_id: int = 0\n    position_offset: int = 0\n    head_dim: int = 0\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.intermediate_size is None or self.intermediate_size == -1:\n            self.intermediate_size = 4 * self.hidden_size\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %s (%d)\",\n                bold(\"prefill_chunk_size\"),\n                bold(\"context_window_size\"),\n                self.context_window_size,\n            )\n            self.prefill_chunk_size = self.context_window_size\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d (%s)\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                self.context_window_size,\n                bold(\"context_window_size\"),\n            )\n            self.prefill_chunk_size = self.context_window_size\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass BertSelfAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: BertConfig):\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // 
config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n\n        self.qkv = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=3 * self.num_heads * self.head_dim,\n            bias=True,\n        )\n\n    def forward(self, hidden_states: Tensor, attention_mask: Tensor):\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape\n\n        qkv = self.qkv(hidden_states)\n        qkv = op.reshape(qkv, (b, s, 3 * h, d))\n        q, k, v = op.split(qkv, 3, axis=2)\n\n        # Attention\n        output = op_ext.attention(q, k, v, attention_mask)\n        return output\n\n\nclass BertSelfOutput(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.dense = nn.Linear(config.hidden_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states: Tensor, input_tensor: Tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertAttention(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.self = BertSelfAttention(config)\n        self.output = BertSelfOutput(config)\n\n    def forward(self, hidden_states: Tensor, attention_mask: Tensor):\n        self_output = self.self(hidden_states, attention_mask)\n        attention_output = self.output(self_output, hidden_states)\n        return attention_output\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass BertIntermediate(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)\n        self.intermediate_act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, hidden_states: Tensor):\n      
  hidden_states = self.dense(hidden_states)\n        hidden_states = self.intermediate_act_fn(hidden_states)\n        return hidden_states\n\n\nclass BertOutput(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)\n        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states: Tensor, input_tensor: Tensor):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass BertLayer(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.attention = BertAttention(config)\n        self.intermediate = BertIntermediate(config)\n        self.output = BertOutput(config)\n\n    def forward(self, hidden_states: Tensor, attention_mask: Tensor):\n        attention_output = self.attention(hidden_states, attention_mask)\n        intermediate_output = self.intermediate(attention_output)\n        layer_output = self.output(intermediate_output, attention_output)\n        return layer_output\n\n\nclass BertEncoder(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])\n\n    def forward(self, hidden_states: Tensor, attention_mask: Tensor):\n        for layer in self.layer:\n            hidden_states = layer(hidden_states, attention_mask)\n        return hidden_states\n\n\nclass BertEmbeddings(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, dtype=\"float32\")\n        self.position_embeddings = nn.Embedding(\n            config.context_window_size, config.hidden_size, dtype=\"float32\"\n        )\n        self.token_type_embeddings = nn.Embedding(\n            config.type_vocab_size, config.hidden_size, dtype=\"float32\"\n        )\n        
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, input_ids: Tensor, token_type_ids: Tensor, position_ids: Tensor):\n        words_embeddings = self.word_embeddings(input_ids)\n        position_embeddings = self.position_embeddings(position_ids)\n        token_type_embeddings = self.token_type_embeddings(token_type_ids)\n\n        embeddings = words_embeddings + position_embeddings + token_type_embeddings\n        embeddings = self.LayerNorm(embeddings)\n        return embeddings\n\n\nclass BertModel(nn.Module):\n    def __init__(self, config: BertConfig):\n        self.embeddings = BertEmbeddings(config)\n        self.encoder = BertEncoder(config)\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def forward(self, inputs: Tensor, attention_mask: Tensor):\n        # TODO: XLM-RoBERTa models use position indices starting from pad_token_id + 1  # pylint: disable=fixme\n        # (e.g., [2, 3, 4, ...] when pad_token_id=1), while this implementation uses\n        # [0, 1, 2, ...]. 
For XLM-RoBERTa models (e.g., bge-m3), the position_embeddings\n        # weights need to be shifted during weight conversion to compensate.\n        def _input_positions(inputs: te.Tensor):\n            b, s = inputs.shape\n            return te.compute((b, s), lambda _, j: j.astype(\"int32\"), name=\"input_positions\")\n\n        input_positions = op.tensor_expr_op(\n            _input_positions,\n            name_hint=\"input_positions\",\n            args=[inputs],\n        )\n\n        token_type_ids = op.zeros(inputs.shape, dtype=\"int32\")\n\n        embeddings = self.embeddings(inputs, token_type_ids, input_positions)\n        encoder_output = self.encoder(embeddings, attention_mask)\n        return encoder_output\n\n    def prefill(self, inputs: Tensor, attention_mask: Tensor):\n        def _attention_mask(mask: te.Tensor, zero, batch_size, seq_len):\n            return te.compute(\n                (batch_size, 1, seq_len, seq_len),\n                lambda b, _, i, j: tir.if_then_else(\n                    tir.any(mask[b, i] == zero, mask[b, j] == zero),\n                    tir.min_value(self.dtype),\n                    tir.max_value(self.dtype),\n                ),\n                name=\"attention_mask_prefill\",\n            )\n\n        batch_size, seq_len = inputs.shape\n        attention_mask_2d = op.tensor_expr_op(\n            _attention_mask,\n            name_hint=\"attention_mask_prefill\",\n            args=[attention_mask, tir.IntImm(\"int32\", 0), batch_size, seq_len],\n        )\n        return self.forward(inputs, attention_mask_2d)\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"prefill\": {\n                \"inputs\": nn.spec.Tensor([\"batch_size\", \"seq_len\"], \"int32\"),\n                \"attention_mask\": nn.spec.Tensor([\"batch_size\", \"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n   
         },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/chatglm3/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/chatglm3/chatglm3_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's ChatGLM3 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .chatglm3_model import ChatGLMForCausalLM, GLMConfig\n\n\ndef huggingface(model_config: GLMConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : GLMConfig\n        The configuration of the ChatGLM3 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = ChatGLMForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    mlc_name = \"transformer.embedding.weight\"\n    mlc_param = named_parameters[mlc_name]\n    mapping.add_mapping(\n        mlc_name,\n        [\"transformer.embedding.word_embeddings.weight\"],\n        functools.partial(\n            lambda x, dtype: x.astype(dtype),\n            dtype=mlc_param.dtype,\n        ),\n    )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/chatglm3/chatglm3_model.py",
    "content": "\"\"\"\nImplementation for CHATGLM3 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass GLMConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the ChatGLM model.\"\"\"\n\n    hidden_size: int\n    num_layers: int\n    kv_channels: int\n    num_attention_heads: int\n    ffn_hidden_size: int\n    layernorm_epsilon: float\n    post_layer_norm: bool\n    rmsnorm: bool\n    add_bias_linear: bool\n    add_qkv_bias: bool\n    apply_query_key_layer_scaling: bool\n    multi_query_attention: bool\n    multi_query_group_num: int\n    vocab_size: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.vocab_size == 0:\n            for name in [\"padded_vocab_size\"]:\n                if name in self.kwargs:\n                    self.vocab_size = self.kwargs.pop(name)\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"seq_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `seq_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass GLMAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GLMConfig):\n        self.hidden_size = config.hidden_size\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards\n        
self.multi_query_attention = config.multi_query_attention\n        self.num_key_value_heads = (\n            config.multi_query_group_num\n            if config.multi_query_attention\n            else config.num_attention_heads\n        ) // config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.query_key_value = nn.Linear(\n            config.hidden_size,\n            (2 * self.num_key_value_heads + self.num_heads) * self.head_dim,\n            bias=config.add_bias_linear or config.add_qkv_bias,\n        )\n        self.dense = nn.Linear(\n            self.num_heads * self.head_dim,\n            config.hidden_size,\n            bias=config.add_bias_linear,\n        )\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.query_key_value(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, h_q, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.dense(output)\n        return attn_output\n\n\nclass GLMMLP(nn.Module):\n    def __init__(self, config: GLMConfig):\n        if config.ffn_hidden_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split ffn hidden size {config.ffn_hidden_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.ffn_hidden_size = config.ffn_hidden_size // config.tensor_parallel_shards\n\n        self.dense_h_to_4h = nn.Linear(\n            config.hidden_size,\n            self.ffn_hidden_size * 2,\n            bias=config.add_bias_linear,\n        )\n        self.dense_4h_to_h = nn.Linear(\n            self.ffn_hidden_size,\n            
config.hidden_size,\n            bias=config.add_bias_linear,\n        )\n\n        def swiglu(x):\n            x = nn.chunk(x, 2, dim=-1)\n            return nn.silu(x[0]) * x[1]\n\n        self.activation_func = swiglu\n\n    def forward(self, x):\n        intermediate_parallel = self.dense_h_to_4h(x)\n        intermediate_parallel = self.activation_func(intermediate_parallel)\n        output = self.dense_4h_to_h(intermediate_parallel)\n        return output\n\n\nclass GLMBlock(nn.Module):\n    def __init__(self, config: GLMConfig):\n        self.self_attention = GLMAttention(config=config)\n        self.mlp = GLMMLP(config)\n        self.input_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.layernorm_epsilon, bias=False\n        )\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.layernorm_epsilon, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attention.num_heads * hd\n            k = self.self_attention.num_key_value_heads * hd\n            v = self.self_attention.num_key_value_heads * hd\n            _set(\n                self.self_attention.query_key_value.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.add_bias_linear or config.add_qkv_bias:\n                _set(\n                    self.self_attention.query_key_value.bias,\n                    tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n            _set(\n                self.self_attention.dense.weight,\n                tp.ShardSingleDim(\"_shard_dense_weight\", dim=1),\n            )\n            if config.add_bias_linear:\n                _set(\n                    self.self_attention.dense.bias,\n                    tp.ShardSingleDim(\"_shard_dense_bias\", 
dim=0),\n                )\n            _set(\n                self.mlp.dense_h_to_4h.weight,\n                tp.ShardSingleDim(\"_shard_dense_h_to_4h_weight\", dim=0),\n            )\n            if config.add_bias_linear:\n                _set(\n                    self.mlp.dense_h_to_4h.bias,\n                    tp.ShardSingleDim(\"_shard_dense_h_to_4h_bias\", dim=0),\n                )\n            _set(\n                self.mlp.dense_4h_to_h.weight,\n                tp.ShardSingleDim(\"_shard_dense_4h_to_h\", dim=1),\n            )\n            if config.add_bias_linear:\n                _set(\n                    self.mlp.dense_4h_to_h.bias,\n                    tp.ShardSingleDim(\"_shard_dense_4h_to_h_bias\", dim=1),\n                )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attention(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass GLMTransformer(nn.Module):\n    \"\"\"Transformer class.\"\"\"\n\n    def __init__(self, config: GLMConfig):\n        self.post_layer_norm = config.post_layer_norm\n\n        # Number of layers.\n        self.num_layers = config.num_layers\n\n        # Transformer layers.\n        self.layers = nn.ModuleList([GLMBlock(config) for _ in range(config.num_layers)])\n\n        if self.post_layer_norm:\n            if config.rmsnorm:\n                self.final_layernorm = nn.RMSNorm(\n                    
config.hidden_size, -1, config.layernorm_epsilon, bias=False\n                )\n            else:\n                self.final_layernorm = nn.LayerNorm(config.hidden_size, config.layernorm_epsilon)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        if self.post_layer_norm:\n            hidden_states = self.final_layernorm(hidden_states)\n        return hidden_states\n\n\nclass ChatGLMModel(nn.Module):\n    def __init__(self, config: GLMConfig):\n        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.encoder = GLMTransformer(config)\n        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        hidden_states = self.encoder(hidden_states, paged_kv_cache)\n        return hidden_states\n\n\nclass ChatGLMForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GLMConfig):\n        self.transformer = ChatGLMModel(config)\n        self.num_hidden_layers = config.num_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = (\n            config.multi_query_group_num\n            if config.multi_query_attention\n            else config.num_attention_heads\n        )\n        self.head_dim = config.head_dim\n        self.vocab_size = config.vocab_size\n        self.rope_theta = 10000\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        
paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.transformer.output_layer(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.transformer.embedding(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.transformer.output_layer(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        logits = self.transformer.output_layer(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = 
self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n  
              \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n  
              \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/cohere/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/cohere/cohere_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Cohere parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization, make_awq_quant\n\nfrom .cohere_model import CohereConfig, CohereForCausalLM\n\nawq_quant = make_awq_quant(CohereForCausalLM)\n\n\ndef _cohere_name_transform(name: str) -> str:\n    if \"out_proj.\" in name:\n        return name.replace(\"out_proj.\", \"o_proj.\")\n    return name\n\n\nhuggingface = make_standard_hf_loader(\n    model_cls=CohereForCausalLM,\n    include_gate_up=False,\n    name_transform=_cohere_name_transform,\n)\n\n\n# https://huggingface.co/alijawad07/aya-23-8B-AWQ-GEMM/tree/main\ndef awq(model_config: CohereConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of AWQ parameters.\n\n    Parameters\n    ----------\n    model_config : CohereConfig\n        The configuration of the Cohere model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to AWQ.\n    \"\"\"\n    model, _ = awq_quant(model_config, quantization)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),  # type: ignore[attr-defined]\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    def _add(mlc_name, hf_name):\n        mapping.add_mapping(\n            mlc_name,\n            [hf_name],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=named_parameters[mlc_name].dtype,\n            ),\n        )\n\n    for i in 
range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{attn}.qkv_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{quantize_suffix}\",\n                    f\"{attn}.k_proj.{quantize_suffix}\",\n                    f\"{attn}.v_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate(\n                        [q, k, v],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n            _add(f\"{attn}.out_proj.{quantize_suffix}\", f\"{attn}.o_proj.{quantize_suffix}\")\n\n        # Map gate, up and down projections in MLP (gate and up are not fused)\n        mlp = f\"model.layers.{i}.mlp\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            _add(f\"{mlp}.up_proj.{quantize_suffix}\", f\"{mlp}.up_proj.{quantize_suffix}\")\n            _add(\n                f\"{mlp}.gate_proj.{quantize_suffix}\",\n                f\"{mlp}.gate_proj.{quantize_suffix}\",\n            )\n            _add(\n                f\"{mlp}.down_proj.{quantize_suffix}\",\n                f\"{mlp}.down_proj.{quantize_suffix}\",\n            )\n\n        # inv_freq is not used in the model\n        # mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),\n     
       )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/cohere/cohere_model.py",
    "content": "\"\"\"\nImplementation for Aya23 architecture\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass CohereConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Cohere Aya-23 model\"\"\"\n\n    model_type: str  # cohere\n    hidden_size: int\n    vocab_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    intermediate_size: int\n    layer_norm_eps: float\n    position_embedding_base: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs[\"rope_theta\"]\n            else:\n                self.position_embedding_base = 10000\n\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n        if self.num_key_value_heads == 0 or self.num_key_value_heads is None:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert (\n            self.head_dim * self.num_attention_heads == self.hidden_size\n        ), \"head_dim * num_attention_heads != hidden_size\"\n        assert (\n            self.num_attention_heads % self.num_key_value_heads == 0\n        ), \"num_attention_heads % num_key_value_heads != 0\"\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass CohereMLP(nn.Module):\n    def __init__(self, config: CohereConfig):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n  
          )\n\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_proj = nn.Linear(config.hidden_size, self.intermediate_size, bias=False)\n        self.up_proj = nn.Linear(config.hidden_size, self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x):\n        down_proj = self.down_proj(op.silu(self.gate_proj(x)) * self.up_proj(x))\n        return down_proj\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass CohereAttention(nn.Module):\n    def __init__(self, config: CohereConfig):\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert config.num_attention_heads % config.tensor_parallel_shards == 0, (\n            f\"num_attention_heads({config.num_attention_heads}) \"\n            \"must be divisible by tensor_parallel_shards\"\n        )\n        self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        assert config.num_key_value_heads % config.tensor_parallel_shards == 0, (\n            f\"num_key_value_heads({config.num_key_value_heads}) \"\n            \"must be divisible by tensor_parallel_shards\"\n        )\n        self.head_dim = config.head_dim\n\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=False,\n        )\n        self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # 
Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.out_proj(output)\n\n\nclass CohereDecoderLayer(nn.Module):\n    def __init__(self, config: CohereConfig):\n        super().__init__()\n        self.self_attn = CohereAttention(config)\n        self.mlp = CohereMLP(config)\n        self.input_layernorm = CohereNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.out_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            # gate_proj and up_proj are separate (unfused) layers, so no segs are needed\n            _set(self.mlp.gate_proj, tp.ShardSingleDim(\"_shard_mlp_gate\", dim=0))\n            _set(self.mlp.up_proj, tp.ShardSingleDim(\"_shard_mlp_up\", dim=0))\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        hidden_ln = self.input_layernorm(hidden_states)\n        attn = self.self_attn(hidden_ln, paged_kv_cache, layer_id)\n        mlp = self.mlp(hidden_ln)\n        hidden_states = self._apply_parallel_residual(attn, residual=hidden_states)  # type: ignore\n        hidden_states = self._apply_parallel_residual(mlp, 
residual=hidden_states)  # type: ignore\n        return hidden_states\n\n    def _apply_parallel_residual(self, mlp_out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(mlp_out + residual / self.tensor_parallel_shards, \"sum\")\n        return mlp_out + residual\n\n\nclass CohereNorm(nn.Module):\n    def __init__(\n        self, normalized_shape: int, eps: float = 1e-5, dtype: Optional[str] = None\n    ) -> None:\n        super().__init__()\n        self.normalized_shape = normalized_shape\n        self.eps = eps\n        self.weight = nn.Parameter((normalized_shape,), dtype=dtype)\n\n    def forward(self, x: Tensor) -> Tensor:\n        return op.layer_norm(\n            x,\n            normalized_shape=self.normalized_shape,\n            weight=self.weight,\n            bias=None,\n            eps=self.eps,\n        )\n\n\nclass CohereEmbedding(nn.Embedding):\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass CohereModel(nn.Module):\n    def __init__(self, config: CohereConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = CohereEmbedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [CohereDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = CohereNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass 
CohereForCausalLM(nn.Module):\n    # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: CohereConfig) -> None:\n        super().__init__()\n        self.model = CohereModel(config)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        lm_logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        if lm_logits.dtype != \"float32\":\n            lm_logits = lm_logits.astype(\"float32\")\n        return lm_logits\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):\n            b, s, d = x.shape  # type: ignore\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        # The LM head is tied to the input embedding weights.\n        logits = self.model.embed_tokens.lm_head_forward(hidden_states)  # type: ignore\n\n   
     if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)  # type: ignore\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)  # type: ignore\n        embeds = self.model.embed_tokens(input_ids)\n        return embeds\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            
max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": 
{\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)  # type: ignore\n"
  },
  {
    "path": "python/mlc_llm/model/deepseek/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/deepseek/deepseek_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Deepseek parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .deepseek_model import DeepseekConfig, DeepseekForCausalLM\n\n\ndef huggingface(model_config: DeepseekConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : DeepseekConfig\n        The configuration of the Deepseek model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = DeepseekForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # map attention weight\n        attn = f\"model.layers.{i}.self_attn\"\n        for weight_type in [\"weight\"]:\n            mlc_name = f\"{attn}.wqkv_pack.{weight_type}\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{weight_type}\",\n                    f\"{attn}.k_proj.{weight_type}\",\n                    f\"{attn}.v_proj.{weight_type}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    
dtype=mlc_param.dtype,\n                ),\n            )\n\n    for i in range(model_config.num_hidden_layers):\n        if i >= model_config.first_k_dense_replace and i % model_config.moe_layer_freq == 0:\n            # map mlp shared expert weight\n            mlp = f\"model.layers.{i}.mlp\"\n            shared_expert = f\"{mlp}.shared_experts\"\n            mlc_name = f\"{shared_expert}.gate_up_proj.weight\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{shared_expert}.gate_proj.weight\",\n                    f\"{shared_expert}.up_proj.weight\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n            # map mlp moe gate and up weight\n            mlc_name = f\"{mlp}.moe_gate_up_proj.weight\"\n            mlc_param = named_parameters[mlc_name]\n\n            def combine_expert_gate_up(*hf_params, dtype):\n                # hf_params alternates (gate_proj, up_proj) for each routed expert\n                stack = []\n                for j in range(0, len(hf_params), 2):\n                    stack.append(np.concatenate([hf_params[j], hf_params[j + 1]], axis=0))\n                return np.stack(stack, axis=0).astype(dtype)\n\n            mapping.add_mapping(\n                mlc_name,\n                functools.reduce(\n                    lambda a, b: a + b,\n                    [\n                        [\n                            f\"{mlp}.experts.{expert_id}.gate_proj.weight\",\n                            f\"{mlp}.experts.{expert_id}.up_proj.weight\",\n                        ]\n                        for expert_id in range(model_config.n_routed_experts)\n                    ],\n                ),\n                functools.partial(\n                    combine_expert_gate_up,\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n            # map mlp moe down 
weight\n            mlc_name = f\"{mlp}.moe_down_proj.weight\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.experts.{expert_id}.down_proj.weight\"\n                    for expert_id in range(model_config.n_routed_experts)\n                ],\n                functools.partial(\n                    lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n        else:\n            # map mlp weight\n            mlp = f\"model.layers.{i}.mlp\"\n            mlc_name = f\"{mlp}.gate_up_proj.weight\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.weight\",\n                    f\"{mlp}.up_proj.weight\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/deepseek/deepseek_model.py",
    "content": "\"\"\"\nImplementation for Deepseek architecture.\n\"\"\"\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.nn.expert import MixtralExperts\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass DeepseekConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Deepseek model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    intermediate_size: int\n    moe_intermediate_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    n_shared_experts: int\n    n_routed_experts: int\n    moe_layer_freq: int\n    first_k_dense_replace: int\n    hidden_act: str\n    norm_topk_prob: bool\n    attention_bias: bool\n    rms_norm_eps: float\n    use_cache: bool\n    bos_token_id: int\n    eos_token_id: int\n    tie_word_embeddings: bool = False\n    rope_theta: int = 10000\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    max_batch_size: int = 1\n    num_experts_per_tok: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass DeepseekAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekConfig):\n        super().__init__()  # Make sure to call the parent class constructor\n        self.hidden_size = config.hidden_size\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} 
attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n\n        self.attention_bias = config.attention_bias\n        self.num_heads = config.num_attention_heads // self.tensor_parallel_shards\n        self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.max_position_embeddings = config.context_window_size\n\n        self.wqkv_pack = nn.Linear(\n            in_features=self.hidden_size,\n            out_features=(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=self.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.num_heads * self.head_dim, self.hidden_size, bias=self.attention_bias\n        )\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.wqkv_pack(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass DeepseekMLP(nn.Module):\n    def __init__(self, config: DeepseekConfig, intermediate_size=None):\n        self.hidden_size = config.hidden_size\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate 
size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        ) // config.tensor_parallel_shards\n\n        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass DeepseekMoE(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekConfig):\n        self.num_local_experts = config.n_routed_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n        self.gate = nn.Linear(config.hidden_size, config.n_routed_experts, bias=False)\n        self.norm_topk_prob = config.norm_topk_prob\n        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards\n        self.moe_gate_up_proj = MixtralExperts(\n            self.num_local_experts,\n            in_features=config.hidden_size,\n            out_features=2 * self.moe_intermediate_size,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n        )\n        self.moe_down_proj = MixtralExperts(\n            self.num_local_experts,\n            in_features=self.moe_intermediate_size,\n            out_features=config.hidden_size,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n        )\n        self.dtype = \"float32\"\n\n        if config.n_shared_experts is not None:\n            intermediate_size = self.moe_intermediate_size * config.n_shared_experts\n            self.shared_experts = DeepseekMLP(config, intermediate_size=intermediate_size)\n\n 
   def forward(self, x: Tensor):  # pylint: disable=too-many-locals\n        def _expert_forward(x: Tensor, indptr: Tensor):\n            x1_x3 = self.moe_gate_up_proj(x, indptr)\n            x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1)\n            x = self.moe_down_proj(op.silu(x1) * x3, indptr)\n            return x\n\n        experts_per_tok = self.num_experts_per_tok\n        num_experts = self.num_local_experts\n        b, s, h = x.shape\n        num_tokens = b * s\n        x = op.reshape(x, (num_tokens, h))\n        gate = self.gate(x)  # (b * s, num_routed_experts)\n        expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(\n            gate, experts_per_tok, norm_topk_prob=self.norm_topk_prob\n        )\n\n        if num_tokens == 1:\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = _expert_forward(x, expert_indices)\n        else:\n            # cumsum: [num_tokens * local_experts]\n            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, num_experts)\n            # indices: [num_tokens * experts_per_tok]\n            reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)\n            # indptr: [num_local_experts + 1]\n            indptr = op_ext.moe_misc.get_indptr(\n                cumsum, num_experts, num_tokens, inclusive=False, out_dtype=\"int32\"\n            )\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = op.take(x, token_indices, axis=0)\n            moe_hidden_states = _expert_forward(moe_hidden_states, indptr)\n            moe_hidden_states = op_ext.moe_misc.scatter_output(moe_hidden_states, reverse_indices)\n\n        # moe_hidden_states: [num_tokens, experts_per_tok, hidden_size]\n        expert_weights = expert_weights.reshape(num_tokens, experts_per_tok, 1)\n        moe_hidden_states = (\n            moe_hidden_states.reshape(num_tokens, experts_per_tok, h) * expert_weights\n        
)\n        # moe_hidden_states: [num_tokens, hidden_size]\n        moe_hidden_states = op_ext.moe_misc.moe_sum(moe_hidden_states, dim=1)\n\n        shared_expert_hidden_states = self.shared_experts(x)\n\n        final_hidden_states = moe_hidden_states + shared_expert_hidden_states\n        final_hidden_states = op.reshape(final_hidden_states, (b, s, h))\n        return final_hidden_states\n\n\nclass DeepseekDecoderLayer(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekConfig, layer_idx: int):\n        self.hidden_size = config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.self_attn = DeepseekAttention(config)\n        self.num_experts = config.n_routed_experts\n        self.mlp = (\n            DeepseekMoE(config)\n            if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekMLP(config)\n        )\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n\n            if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            ):\n                i = self.mlp.moe_intermediate_size\n            else:\n                i = self.mlp.intermediate_size\n            _set(\n                
self.self_attn.wqkv_pack.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n\n            if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            ):\n                _set(\n                    self.mlp.moe_gate_up_proj.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=1),\n                )\n                _set(\n                    self.mlp.moe_down_proj.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_down\", dim=2),\n                )\n\n            else:\n                _set(\n                    self.mlp.gate_up_proj.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n                )\n                _set(\n                    self.mlp.down_proj.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_down\", dim=1),\n                )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.input_layernorm(hidden_states)\n        out = self.self_attn(out, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)  # type: ignore[operator]\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass DeepseekModel(nn.Module):\n    def __init__(self, config: DeepseekConfig):\n        assert 
config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [\n                DeepseekDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass DeepseekForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekConfig):\n        self.model = DeepseekModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.rope_theta = config.rope_theta\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            
hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last token\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def 
batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": 
nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/deepseek_v2/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/deepseek_v2/deepseek_v2_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Deepseek-V2 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\nfrom typing import Callable, List\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping, QuantizeMapping\nfrom mlc_llm.quantization import BlockScaleQuantize, Quantization\n\nfrom .deepseek_v2_model import DeepseekV2Config, DeepseekV2ForCausalLM\n\n\ndef huggingface(  # pylint: disable=too-many-locals,too-many-statements\n    model_config: DeepseekV2Config, quantization: Quantization\n) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : DeepseekV2Config\n        The configuration of the DeepseekV2 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = DeepseekV2ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    if isinstance(quantization, BlockScaleQuantize):\n        # Convert the model to block-scale quantized model before loading parameters\n        model = quantization.quantize_model(model, QuantizeMapping({}, {}), \"\")\n        if model_config.weight_block_size is None:\n            raise ValueError(\n                \"The input DeepSeek model is not fp8 block quantized. 
\"\n                \"Thus BlockScaleQuantize is not supported.\"\n            )\n\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    if (\n        not isinstance(quantization, BlockScaleQuantize)\n        and model_config.weight_block_size is not None\n    ):\n        raise ValueError(\n            \"The input DeepSeek model is fp8 block quantized. \"\n            \"Please use BlockScaleQuantize for the model.\"\n        )\n\n    # Helper function to add both weight and scale mappings\n    def add_weight_and_scale_mapping(\n        weight_mlc_name: str,\n        weight_hf_names: List[str],\n        weight_transform_func: Callable,\n    ):\n        mlc_param = named_parameters[weight_mlc_name]\n        mapping.add_mapping(\n            weight_mlc_name,\n            weight_hf_names,\n            functools.partial(weight_transform_func, dtype=mlc_param.dtype),\n        )\n\n        if isinstance(quantization, BlockScaleQuantize):\n            scale_mlc_name = f\"{weight_mlc_name}_scale_inv\"\n            if scale_mlc_name in named_parameters:\n                scale_hf_names = [f\"{name}_scale_inv\" for name in weight_hf_names]\n                scale_param = named_parameters[scale_mlc_name]\n                mapping.add_mapping(\n                    scale_mlc_name,\n                    scale_hf_names,\n                    functools.partial(weight_transform_func, dtype=scale_param.dtype),\n                )\n\n    for i in range(model_config.num_hidden_layers):\n        if i >= model_config.first_k_dense_replace and i % model_config.moe_layer_freq == 0:\n            # map mlp shared expert weight\n            mlp = f\"model.layers.{i}.mlp\"\n            shared_expert = f\"{mlp}.shared_experts\"\n            add_weight_and_scale_mapping(\n                f\"{shared_expert}.gate_up_proj.weight\",\n   
             [\n                    f\"{shared_expert}.gate_proj.weight\",\n                    f\"{shared_expert}.up_proj.weight\",\n                ],\n                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n            )\n\n            # map mlp moe gate and up weight\n            def combine_expert_gate_up(*hf_params, dtype):\n                stack = []\n                for i in range(0, len(hf_params), 2):\n                    stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))\n                return np.stack(stack, axis=0).astype(dtype)\n\n            add_weight_and_scale_mapping(\n                f\"{mlp}.moe_gate_up_proj.weight\",\n                functools.reduce(\n                    lambda a, b: a + b,\n                    [\n                        [\n                            f\"{mlp}.experts.{expert_id}.gate_proj.weight\",\n                            f\"{mlp}.experts.{expert_id}.up_proj.weight\",\n                        ]\n                        for expert_id in range(model_config.n_routed_experts)\n                    ],\n                ),\n                combine_expert_gate_up,\n            )\n\n            # map mlp moe down projection weight\n            add_weight_and_scale_mapping(\n                f\"{mlp}.moe_down_proj.weight\",\n                [\n                    f\"{mlp}.experts.{expert_id}.down_proj.weight\"\n                    for expert_id in range(model_config.n_routed_experts)\n                ],\n                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),\n            )\n\n            # map moe e_score_correction_bias\n            if model_config.topk_method == \"noaux_tc\":\n                mlc_name = f\"{mlp}.e_score_correction_bias\"\n                mlc_param = named_parameters[mlc_name]\n                mapping.add_mapping(\n                    mlc_name,\n                    [f\"{mlp}.gate.e_score_correction_bias\"],\n                    
functools.partial(\n                        lambda x, dtype: x.astype(dtype),\n                        dtype=mlc_param.dtype,\n                    ),\n                )\n        else:\n            # map mlp weight\n            mlp = f\"model.layers.{i}.mlp\"\n            add_weight_and_scale_mapping(\n                f\"{mlp}.gate_up_proj.weight\",\n                [\n                    f\"{mlp}.gate_proj.weight\",\n                    f\"{mlp}.up_proj.weight\",\n                ],\n                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n            )\n\n        # map MLA kv_b_proj weight\n        attn = f\"model.layers.{i}.self_attn\"\n        mlc_name = f\"{attn}.w_uk\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [f\"{attn}.kv_b_proj.weight\"],\n            functools.partial(\n                lambda kv_b_proj, dtype: (\n                    np.split(\n                        kv_b_proj.reshape(\n                            model_config.num_key_value_heads,\n                            model_config.qk_nope_head_dim + model_config.v_head_dim,\n                            model_config.kv_lora_rank,\n                        ),\n                        indices_or_sections=[model_config.qk_nope_head_dim],\n                        axis=1,\n                    )[0]\n                    .transpose(0, 2, 1)\n                    .astype(dtype)\n                ),\n                dtype=mlc_param.dtype,\n            ),\n        )\n        if isinstance(quantization, BlockScaleQuantize):\n            scale_mlc_name = f\"{attn}.w_uk_scale_inv\"\n            mlc_param = named_parameters[scale_mlc_name]\n            mapping.add_mapping(\n                scale_mlc_name,\n                [f\"{attn}.kv_b_proj.weight_scale_inv\"],\n                functools.partial(\n                    lambda kv_b_proj, dtype: (\n                        np.split(\n                            
kv_b_proj.reshape(\n                                model_config.num_key_value_heads,\n                                (model_config.qk_nope_head_dim + model_config.v_head_dim)\n                                // quantization.weight_block_size[0],\n                                model_config.kv_lora_rank // quantization.weight_block_size[1],\n                            ),\n                            indices_or_sections=[\n                                model_config.qk_nope_head_dim // quantization.weight_block_size[0]\n                            ],\n                            axis=1,\n                        )[0]\n                        .transpose(0, 2, 1)\n                        .astype(dtype)\n                    ),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n        mlc_name = f\"{attn}.w_uv\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [f\"{attn}.kv_b_proj.weight\"],\n            functools.partial(\n                lambda kv_b_proj, dtype: np.split(\n                    kv_b_proj.reshape(\n                        model_config.num_key_value_heads,\n                        model_config.qk_nope_head_dim + model_config.v_head_dim,\n                        model_config.kv_lora_rank,\n                    ),\n                    indices_or_sections=[model_config.qk_nope_head_dim],\n                    axis=1,\n                )[1].astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n        if isinstance(quantization, BlockScaleQuantize):\n            scale_mlc_name = f\"{attn}.w_uv_scale_inv\"\n            mlc_param = named_parameters[scale_mlc_name]\n            mapping.add_mapping(\n                scale_mlc_name,\n                [f\"{attn}.kv_b_proj.weight_scale_inv\"],\n                functools.partial(\n                    lambda kv_b_proj, dtype: np.split(\n                        kv_b_proj.reshape(\n                  
          model_config.num_key_value_heads,\n                            (model_config.qk_nope_head_dim + model_config.v_head_dim)\n                            // quantization.weight_block_size[0],\n                            model_config.kv_lora_rank // quantization.weight_block_size[1],\n                        ),\n                        indices_or_sections=[\n                            model_config.qk_nope_head_dim // quantization.weight_block_size[0]\n                        ],\n                        axis=1,\n                    )[1].astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/deepseek_v2/deepseek_v2_model.py",
    "content": "\"\"\"\nImplementation for Deepseek V2 architecture\n\"\"\"\n\nimport dataclasses\nimport math\nfrom typing import Any, Dict, Literal, Optional, Tuple\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\nfrom tvm.relax.frontend.nn.llm import position_embedding\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.nn.expert import MixtralExperts\nfrom mlc_llm.op import batch_matmul\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass DeepseekV2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Deepseek V2 model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    intermediate_size: int\n    moe_intermediate_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    n_shared_experts: int\n    n_routed_experts: int\n    num_experts_per_tok: int\n    norm_topk_prob: bool\n    first_k_dense_replace: int\n    moe_layer_freq: int\n    routed_scaling_factor: float\n    scoring_func: str\n    topk_method: Literal[\"greedy\", \"group_limited_greedy\", \"noaux_tc\"]\n    n_group: int\n    topk_group: int\n    attention_bias: bool\n    kv_lora_rank: int\n    qk_rope_head_dim: int\n    v_head_dim: int\n    qk_nope_head_dim: int\n    rms_norm_eps: float\n    rope_theta: int\n    q_lora_rank: Optional[int] = None\n    rope_scaling: Optional[Dict[str, Any]] = None\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    dtype: str = \"float32\"\n    max_batch_size: int = 1\n    weight_block_size: Optional[Tuple[int, int]] = None\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def 
__post_init__(self):\n        if \"quantization_config\" in self.kwargs:\n            quantization_config = self.kwargs.get(\"quantization_config\")\n            if (\n                isinstance(quantization_config, dict)\n                and quantization_config.get(\"activation_scheme\", \"\") == \"dynamic\"\n                and quantization_config.get(\"fmt\", \"\") == \"e4m3\"\n                and quantization_config.get(\"quant_method\", \"\") == \"fp8\"\n                and \"weight_block_size\" in quantization_config\n            ):\n                self.weight_block_size = quantization_config.get(\"weight_block_size\")\n                if (\n                    not isinstance(self.weight_block_size, (tuple, list))\n                    or len(self.weight_block_size) != 2\n                ):\n                    raise ValueError(\n                        \"Invalid DeepSeek model quantization config: \"\n                        \"weight_block_size must be a tuple of two integers, \"\n                        f\"got {self.weight_block_size} of type {type(self.weight_block_size)}\"\n                    )\n            else:\n                raise ValueError(\n                    \"Invalid DeepSeek model quantization config: unrecognized quantization config: \"\n                    f\"{quantization_config}\"\n                )\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 2048),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 2048)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 2048),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 2048)\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass DeepseekV2MLP(nn.Module):\n    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):\n        super().__init__()\n        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size\n        intermediate_size = (\n            config.intermediate_size if intermediate_size is None else intermediate_size\n        )\n        if intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MoE intermediate size {intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n   
     self.intermediate_size = intermediate_size // config.tensor_parallel_shards\n\n        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n\n    def forward(self, x: Tensor) -> Tensor:\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\ndef yarn_get_mscale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\nclass DeepseekV2YarnRotaryEmbedding(nn.Module):\n    def __init__(self, config: DeepseekV2Config):\n        self.rope_fn = position_embedding.switch_rope_freq_func(config.rope_scaling)\n        self.rotary_dim = config.qk_rope_head_dim\n        self.theta = config.rope_theta\n\n    def forward(\n        self,\n        q: Tensor,\n        k: Tensor,\n        positions: Tensor,\n    ):\n        def _rope_fused(x: te.Tensor, positions: te.Tensor):\n            _, _, _, d_dim = x.shape\n            d_dim_half = d_dim // 2\n            dtype = x.dtype\n\n            def compute(b: tir.Var, s: tir.Var, h: tir.Var, d: tir.Var):\n                d1 = d // d_dim_half\n                d2 = d % d_dim_half\n\n                cos_freq, sin_freq, var_map = self.rope_fn(\n                    positions[s], d, self.rotary_dim, self.theta, dtype\n                )\n                cos = x[b, s, h, d2 * 2 + d1] * cos_freq\n\n                partner_d = tir.if_then_else(\n                    d < self.rotary_dim // 2,\n                    d + self.rotary_dim // 2,\n                    d - self.rotary_dim // 2,\n                )\n\n                partner_d1 = partner_d // d_dim_half\n                partner_d2 = partner_d % d_dim_half\n                sin = (\n                    x[b, s, h, partner_d2 * 2 + partner_d1]\n                    * sin_freq\n                    * tir.if_then_else(\n  
                      d < self.rotary_dim // 2,\n                        tir.const(-1, dtype),\n                        tir.const(1, dtype),\n                    )\n                )\n                expr = cos + sin\n                for var, val in var_map.items():\n                    expr = tir.Let(var, val, expr)\n                return expr\n\n            return te.compute(x.shape, compute, name=\"yarn_rope\")\n\n        q_embed = op.tensor_expr_op(_rope_fused, \"rope\", [q, positions])\n        k_embed = op.tensor_expr_op(_rope_fused, \"rope\", [k, positions])\n        return q_embed, k_embed\n\n\nclass DeepseekV2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekV2Config):\n        super().__init__()\n        self.config = config\n        self.hidden_size = config.hidden_size\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards\n\n        self.rope_theta = config.rope_theta\n        self.q_lora_rank = config.q_lora_rank\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.kv_lora_rank = config.kv_lora_rank\n        self.v_head_dim = config.v_head_dim\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim\n        self.block_size = config.weight_block_size\n\n        if self.q_lora_rank is None:\n            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias=False)\n        else:\n            self.q_a_proj = nn.Linear(\n                self.hidden_size, config.q_lora_rank, bias=config.attention_bias\n            )\n            self.q_a_layernorm = 
nn.RMSNorm(config.q_lora_rank, -1, config.rms_norm_eps, bias=False)\n            self.q_b_proj = nn.Linear(\n                config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False\n            )\n\n        self.kv_a_proj_with_mqa = nn.Linear(\n            self.hidden_size,\n            config.kv_lora_rank + config.qk_rope_head_dim,\n            bias=config.attention_bias,\n        )\n        self.kv_a_layernorm = nn.RMSNorm(config.kv_lora_rank, -1, config.rms_norm_eps, bias=False)\n        self.kv_b_proj = nn.Linear(\n            config.kv_lora_rank,\n            self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim),\n            bias=False,\n        )\n        self.w_uk = nn.Parameter((self.num_heads, config.kv_lora_rank, self.qk_nope_head_dim))\n        self.w_uv = nn.Parameter((self.num_heads, self.v_head_dim, config.kv_lora_rank))\n\n        self.o_proj = nn.Linear(\n            self.num_heads * self.v_head_dim,\n            self.hidden_size,\n            bias=config.attention_bias,\n        )\n        self.softmax_scale = self.q_head_dim ** (-0.5)\n        if self.config.rope_scaling is not None:\n            mscale_all_dim = self.config.rope_scaling.get(\"mscale_all_dim\", 0)\n            scaling_factor = self.config.rope_scaling[\"factor\"]\n            if mscale_all_dim:\n                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * mscale * mscale\n        self.rotary_emb = DeepseekV2YarnRotaryEmbedding(config)\n\n    def forward(  # pylint: disable=too-many-arguments\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n        query_positions: Tensor,\n        forward_mode: Literal[\"prefill\", \"decode\", \"extend\"],\n    ) -> Tuple[Tensor, PagedKVCache]:\n        b, s, _ = hidden_states.shape\n\n        if self.q_lora_rank is None:\n            q = self.q_proj(hidden_states)\n        else:\n    
        q = self.q_b_proj(\n                self.q_a_layernorm(self.q_a_proj(hidden_states))\n            )  # (b, s, num_heads * q_head_dim)\n        q = op.reshape(q, (b, s, self.num_heads, self.q_head_dim))  # (b, s, num_heads, q_head_dim)\n        q_nope, q_pe = op.split(\n            q, [self.qk_nope_head_dim], axis=-1\n        )  # (b, s, num_heads, qk_nope_head_dim), (b, s, num_heads, qk_rope_head_dim)\n        compressed_kv = self.kv_a_proj_with_mqa(hidden_states).reshape(\n            b, s, 1, self.kv_lora_rank + self.qk_rope_head_dim\n        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)\n        compressed_kv, k_pe = op.split(\n            compressed_kv, [self.config.kv_lora_rank], axis=-1\n        )  # (b, s, 1, kv_lora_rank), (b, s, 1, qk_rope_head_dim)\n        compressed_kv = self.kv_a_layernorm(compressed_kv)\n        q_pe, k_pe = self.rotary_emb(q_pe, k_pe, query_positions)\n        kv_states = op.concat(\n            [compressed_kv, k_pe], dim=-1\n        )  # (b, s, 1, kv_lora_rank + qk_rope_head_dim)\n        paged_kv_cache = paged_kv_cache.append_mla_kv(layer_id, kv_states)\n\n        if forward_mode == \"prefill\":\n            output, _ = self.self_attn(q_nope, compressed_kv, q_pe, k_pe, paged_kv_cache, layer_id)\n        elif forward_mode == \"decode\":\n            output, _ = self.cross_attn(q_nope, q_pe, paged_kv_cache, layer_id)\n        elif forward_mode == \"extend\":\n            o1, lse1 = self.self_attn(q_nope, compressed_kv, q_pe, k_pe, paged_kv_cache, layer_id)\n            o2, lse2 = self.cross_attn(q_nope, q_pe, paged_kv_cache, layer_id)\n            output, _ = paged_kv_cache.merge_attn_output_inplace(o1, lse1, o2, lse2)\n        else:\n            raise ValueError(f\"Invalid forward mode: {forward_mode}\")\n\n        return self.o_proj(output.reshape(b, s, self.num_heads * self.v_head_dim)), paged_kv_cache\n\n    def self_attn(  # pylint: disable=too-many-arguments\n        self,\n        q_nope: Tensor,\n        
compressed_kv: Tensor,\n        q_pe: Tensor,\n        k_pe: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n    ) -> Tuple[Tensor, Tensor]:\n        b, s, _, _ = q_nope.shape\n        q = op.concat(\n            [q_nope, q_pe], dim=-1\n        )  # (b, s, num_heads, qk_nope_head_dim + qk_rope_head_dim)\n        kv = op.reshape(\n            self.kv_b_proj(compressed_kv),\n            (b, s, self.num_heads, self.qk_nope_head_dim + self.v_head_dim),\n        )\n        k, v = op.split(kv, [self.qk_nope_head_dim], axis=-1)\n        k_pe = op.broadcast_to(k_pe, (b, s, self.num_heads, self.qk_rope_head_dim))\n        k = op.concat([k, k_pe], dim=-1)\n        output, lse = paged_kv_cache.self_attention(layer_id, q, k, v, self.softmax_scale)\n        return output, lse\n\n    def cross_attn(\n        self,\n        q_nope: Tensor,\n        q_pe: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n    ) -> Tuple[Tensor, Tensor]:\n        b, s, _, _ = q_nope.shape\n        if not hasattr(self, \"w_uk_scale_inv\"):\n            q_nope = op.matmul(\n                q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),\n                self.w_uk.permute_dims(0, 2, 1),\n            )\n        else:\n            q_nope = batch_matmul.quantized_bmm(\n                q_nope.reshape(b * s, self.num_heads, self.qk_nope_head_dim).permute_dims(1, 0, 2),\n                self.w_uk,\n                self.w_uk_scale_inv,  # pylint: disable=no-member\n                self.block_size,\n            )\n        q_nope = q_nope.permute_dims(1, 0, 2).reshape(\n            b, s, self.num_heads, self.kv_lora_rank\n        )  # (b, s, num_heads, kv_lora_rank)\n        query_states = op.concat(\n            [q_nope, q_pe], dim=-1\n        )  # (b, s, num_heads, kv_lora_rank + qk_rope_head_dim)\n\n        output, lse = paged_kv_cache.cross_attention(\n            layer_id,\n            query_states,\n            
v_head_dim=self.kv_lora_rank,\n            sm_scale=self.softmax_scale,\n        )  # (b, s, num_heads, kv_lora_rank)\n        if getattr(self, \"w_uv_scale_inv\", None) is None:\n            output = op.matmul(\n                output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),\n                self.w_uv.permute_dims(0, 2, 1),\n            )\n        else:\n            output = batch_matmul.quantized_bmm(\n                output.reshape(b * s, self.num_heads, self.kv_lora_rank).permute_dims(1, 0, 2),\n                self.w_uv,\n                self.w_uv_scale_inv,  # pylint: disable=no-member\n                self.block_size,\n            )\n        output = output.permute_dims(1, 0, 2).reshape(b, s, self.num_heads * self.v_head_dim)\n        return output, lse\n\n\nclass DeepseekV2MoE(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekV2Config):\n        super().__init__()\n        self.num_experts_per_tok = config.num_experts_per_tok\n        self.num_routed_experts = config.n_routed_experts\n        self.routed_scaling_factor = config.routed_scaling_factor\n        self.scoring_func = config.scoring_func\n        self.topk_method = config.topk_method\n        self.n_group = config.n_group\n        self.topk_group = config.topk_group\n\n        self.gate = nn.Linear(\n            config.hidden_size, self.num_routed_experts, bias=False, out_dtype=\"float32\"\n        )\n        self.e_score_correction_bias = (\n            nn.Parameter((config.n_routed_experts,), dtype=\"float32\")\n            if config.topk_method == \"noaux_tc\"\n            else None\n        )\n        self.norm_topk_prob = config.norm_topk_prob\n        if config.moe_intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MoE intermediate size {config.moe_intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n    
        )\n        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards\n\n        self.moe_gate_up_proj = MixtralExperts(\n            self.num_routed_experts,\n            in_features=config.hidden_size,\n            out_features=2 * self.moe_intermediate_size,\n        )\n        self.moe_down_proj = MixtralExperts(\n            self.num_routed_experts,\n            in_features=self.moe_intermediate_size,\n            out_features=config.hidden_size,\n        )\n\n        self.shared_experts = DeepseekV2MLP(\n            config,\n            intermediate_size=config.moe_intermediate_size * config.n_shared_experts,\n        )\n        self.dtype = \"float32\"\n\n    def forward(self, x: Tensor):\n        def _expert_forward(x: Tensor, indptr: Tensor):\n            x1_x2 = self.moe_gate_up_proj(x, indptr)\n            x1, x2 = op.split(x1_x2, indices_or_sections=2, axis=-1)\n            x = self.moe_down_proj(op.silu(x1) * x2, indptr)\n            return x\n\n        experts_per_tok = self.num_experts_per_tok\n        num_experts = self.num_routed_experts\n        b, s, h = x.shape\n        num_tokens = b * s\n        x = op.reshape(x, (num_tokens, h))\n        logits = self.gate(x)  # (num_tokens, num_routed_experts)\n        assert logits.dtype == \"float32\"\n        if self.scoring_func == \"softmax\":\n            scores = op.softmax(logits, axis=-1)\n        elif self.scoring_func == \"sigmoid\":\n            scores = op.sigmoid(logits)\n        else:\n            raise ValueError(f\"Unsupported deepseek scoring function: {self.scoring_func}\")\n\n        # select top-k experts\n        if self.topk_method == \"greedy\":\n            expert_weights, expert_indices = op_ext.moe_misc.gating_topk(scores, experts_per_tok)\n        elif self.topk_method in [\"group_limited_greedy\", \"noaux_tc\"]:\n            expert_weights, expert_indices = op_ext.moe_misc.group_limited_greedy_topk(\n                scores,\n                
self.num_experts_per_tok,\n                self.num_routed_experts,\n                self.n_group,\n                self.topk_group,\n                self.topk_method,\n                num_tokens,\n                self.e_score_correction_bias,\n            )\n        else:\n            raise ValueError(f\"Unsupported deepseek topk method: {self.topk_method}\")\n\n        if self.num_experts_per_tok > 1 and self.norm_topk_prob:\n            denominator = op.sum(expert_weights, axis=-1, keepdims=True) + 1e-20\n            expert_weights = expert_weights / denominator\n        expert_weights = expert_weights * self.routed_scaling_factor\n\n        use_cutlass = op_ext.get_store().cutlass_group_gemm and self.dtype in [\n            \"float16\",\n            \"bfloat16\",\n        ]\n\n        if num_tokens == 1:\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = _expert_forward(x, expert_indices)\n        else:\n            # cumsum: [num_tokens * local_experts]\n            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, num_experts)\n            # indices: [num_tokens * experts_per_tok]\n            reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)\n            if use_cutlass:\n                # indptr: [num_routed_experts]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum, num_experts, num_tokens, inclusive=True, out_dtype=\"int64\"\n                )\n            else:\n                # indptr: [num_routed_experts + 1]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum, num_experts, num_tokens, inclusive=False, out_dtype=\"int32\"\n                )\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = op.take(x, token_indices, axis=0)\n            moe_hidden_states = _expert_forward(moe_hidden_states, indptr)\n            moe_hidden_states = 
op_ext.moe_misc.scatter_output(moe_hidden_states, reverse_indices)\n\n        # moe_hidden_states: [num_tokens, experts_per_tok, hidden_size]\n        expert_weights = expert_weights.reshape(num_tokens, experts_per_tok, 1).astype(x.dtype)\n        moe_hidden_states = (\n            moe_hidden_states.reshape(num_tokens, experts_per_tok, h) * expert_weights\n        )\n        # moe_hidden_states: [num_tokens, hidden_size]\n        moe_hidden_states = op_ext.moe_misc.moe_sum(moe_hidden_states, dim=1)\n\n        shared_expert_hidden_states = self.shared_experts(x)\n\n        final_hidden_states = moe_hidden_states + shared_expert_hidden_states\n        final_hidden_states = op.reshape(final_hidden_states, (b, s, h))\n        return final_hidden_states\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        # Force e_score_correction_bias to be float32\n        if self.e_score_correction_bias is not None:\n            self.e_score_correction_bias.to(\"float32\")\n\n\nclass DeepseekV2DecoderLayer(nn.Module):\n    def __init__(self, config: DeepseekV2Config, layer_idx: int):\n        super().__init__()\n        self.self_attn = DeepseekV2Attention(config)\n        self.mlp = (\n            DeepseekV2MoE(config)\n            if (\n                config.n_routed_experts is not None\n                and layer_idx >= config.first_k_dense_replace\n                and layer_idx % config.moe_layer_freq == 0\n            )\n            else DeepseekV2MLP(config)\n        )\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            if self.self_attn.q_lora_rank is None:\n                _set(\n                    
self.self_attn.q_proj.weight,\n                    tp.ShardSingleDim(\"_shard_q_weight\", dim=0),\n                )\n            else:\n                _set(\n                    self.self_attn.q_b_proj.weight,\n                    tp.ShardSingleDim(\"_shard_q_b_weight\", dim=0),\n                )\n\n            _set(\n                self.self_attn.kv_b_proj.weight,\n                tp.ShardSingleDim(\"_shard_kv_b_weight\", dim=0),\n            )\n            _set(\n                self.self_attn.w_uk,\n                tp.ShardSingleDim(\"_shard_kv_b_weight_w_uk\", dim=0),\n            )\n            _set(\n                self.self_attn.w_uv,\n                tp.ShardSingleDim(\"_shard_kv_b_weight_w_uv\", dim=0),\n            )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n\n            if isinstance(self.mlp, DeepseekV2MoE):\n                si = self.mlp.shared_experts.intermediate_size\n                mi = self.mlp.moe_intermediate_size\n                _set(\n                    self.mlp.shared_experts.gate_up_proj.weight,\n                    tp.ShardSingleDim(\"_shard_shared_experts_gate_up\", segs=[si, si], dim=0),\n                )\n                _set(\n                    self.mlp.shared_experts.down_proj.weight,\n                    tp.ShardSingleDim(\"_shard_shared_experts_down\", dim=1),\n                )\n                _set(\n                    self.mlp.moe_gate_up_proj.weight,\n                    tp.ShardSingleDim(\"_shard_moe_gate_up\", segs=[mi, mi], dim=1),\n                )\n                _set(\n                    self.mlp.moe_down_proj.weight,\n                    tp.ShardSingleDim(\"_shard_moe_mlp_down\", dim=2),\n                )\n            else:\n                assert isinstance(self.mlp, DeepseekV2MLP)\n                si = self.mlp.intermediate_size\n                _set(\n                    self.mlp.gate_up_proj.weight,\n                    
tp.ShardSingleDim(\"_shard_gate_up\", segs=[si, si], dim=0),\n                )\n                _set(\n                    self.mlp.down_proj.weight,\n                    tp.ShardSingleDim(\"_shard_down\", dim=1),\n                )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(  # pylint: disable=too-many-arguments\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n        query_positions: Tensor,\n        forward_mode: Literal[\"prefill\", \"decode\", \"extend\"],\n    ) -> Tuple[Tensor, PagedKVCache]:\n        out = self.input_layernorm(hidden_states)\n        out, paged_kv_cache = self.self_attn(\n            out, paged_kv_cache, layer_id, query_positions, forward_mode\n        )\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)  # type: ignore[operator]\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states, paged_kv_cache\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass DeepseekV2Model(nn.Module):\n    def __init__(self, config: DeepseekV2Config):\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [\n                DeepseekV2DecoderLayer(config, layer_idx)\n                for layer_idx in range(config.num_hidden_layers)\n            ]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(\n        self,\n        inputs: Tensor,\n        paged_kv_cache: PagedKVCache,\n        forward_mode: Literal[\"prefill\", \"decode\", \"extend\"],\n    ):\n        hidden_states = inputs\n        
query_positions = paged_kv_cache.get_query_positions(inputs.shape[0] * inputs.shape[1])\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states, paged_kv_cache = layer(\n                hidden_states, paged_kv_cache, layer_id, query_positions, forward_mode\n            )\n        hidden_states = self.norm(hidden_states)\n        return hidden_states, paged_kv_cache\n\n\nclass DeepseekV2ForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: DeepseekV2Config):\n        self.model = DeepseekV2Model(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.dtype = config.dtype\n        self.hidden_size = config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.intermediate_size = config.intermediate_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.kv_lora_rank = config.kv_lora_rank\n        self.qk_nope_head_dim = config.qk_nope_head_dim\n        self.qk_rope_head_dim = config.qk_rope_head_dim\n        self.v_head_dim = config.v_head_dim\n        self.rms_norm_eps = config.rms_norm_eps\n        self.rope_theta = config.rope_theta\n        self.vocab_size = config.vocab_size\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.weight_block_size = config.weight_block_size\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        forward_mode: Literal[\"prefill\", \"decode\", \"extend\"],\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states, paged_kv_cache = self.model(input_embeds, paged_kv_cache, forward_mode)\n        if 
logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states, paged_kv_cache = self.model(input_embed, paged_kv_cache, \"prefill\")\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def extend(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states, paged_kv_cache = self.model(input_embed, paged_kv_cache, \"extend\")\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states, paged_kv_cache = self.model(input_embed, paged_kv_cache, \"decode\")\n        logits = 
self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits, paged_kv_cache = self.batch_forward(\n            input_embeds, paged_kv_cache, \"prefill\", logit_positions\n        )\n        return logits, paged_kv_cache\n\n    def batch_extend(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits, paged_kv_cache = self.batch_forward(\n            input_embeds, paged_kv_cache, \"extend\", logit_positions\n        )\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits, paged_kv_cache = self.batch_forward(input_embeds, paged_kv_cache, \"decode\", None)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits, paged_kv_cache = self.batch_forward(input_embeds, paged_kv_cache, \"extend\", None)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mla\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            
prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=1,\n            qk_head_dim=self.kv_lora_rank + self.qk_rope_head_dim,\n            v_head_dim=self.kv_lora_rank,\n            mla_original_qk_head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim,\n            mla_original_v_head_dim=self.v_head_dim,\n            rope_mode=RopeMode.NONE,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"extend\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    
\"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_extend\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                
\"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/eagle/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/eagle/eagle_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's EAGLE parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization, make_awq_quant\n\nfrom .eagle_model import EagleConfig, EagleForCausalLM\n\nawq_quant = make_awq_quant(EagleForCausalLM)\n\n\nhuggingface = make_standard_hf_loader(\n    model_cls=EagleForCausalLM,\n    layer_prefix=\"layers\",\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n\n\ndef awq(model_config: EagleConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of AWQ parameters.\n    Parameters\n    ----------\n    model_config : EagleConfig\n        The configuration of the Eagle model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to AWQ.\n    \"\"\"\n    model, _ = awq_quant(model_config, quantization)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),  # type: ignore[attr-defined]\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"layers.{i}.self_attn\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{attn}.qkv_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{quantize_suffix}\",\n                    
f\"{attn}.k_proj.{quantize_suffix}\",\n                    f\"{attn}.v_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate(\n                        [q, k, v],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # Concat gate and up in MLP\n        mlp = f\"layers.{i}.mlp\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{mlp}.gate_up_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.{quantize_suffix}\",\n                    f\"{mlp}.up_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate(\n                        [gate, up],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/eagle/eagle_model.py",
    "content": "\"\"\"\nImplementation for EAGLE architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Optional\n\nfrom tvm import tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.llama.llama_model import LlamaAttention, LlamaConfig, LlamaFFN\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass EagleConfig(LlamaConfig):\n    \"\"\"Configuration of the Eagle model.\"\"\"\n\n    bias: bool = True  # Whether to use bias in the fc layers\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass EagleDecoderLayer(nn.Module):\n    def __init__(self, config: EagleConfig, index: int):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = LlamaAttention(config)\n        self.mlp = LlamaFFN(config)\n        self.index = index\n        if self.index != 0:\n            self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            
_set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        if self.index != 0:\n            hidden_states = self.input_layernorm(hidden_states)\n        out = self.self_attn(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass EagleForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: EagleConfig):\n        # Put the model definition here to align with EAGLE's original structure\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [EagleDecoderLayer(config, i) for i in range(config.num_hidden_layers)]\n        )\n        self.fc = nn.Linear(\n            in_features=2 * config.hidden_size,\n            out_features=config.hidden_size,\n            bias=config.bias,\n        )\n\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = 
\"float32\"\n\n    def fuse_embed_hidden_states(self, input_embed: Tensor, hidden_states: Tensor):\n        hidden_states = op.concat([input_embed, hidden_states], dim=-1)\n        hidden_states = self.fc(hidden_states)\n        return hidden_states\n\n    def forward_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache):\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        return hidden_states\n\n    def forward(self, input_embed: Tensor, hidden_states: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = self.fuse_embed_hidden_states(input_embed, hidden_states)\n        hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache)\n        return hidden_states\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return hidden_states\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.embed_tokens(input_ids)\n\n    def prefill_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def decode_to_last_hidden_states(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache):\n        
op_ext.configure()\n\n        hidden_states = self.forward_to_last_hidden_states(hidden_states, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_prefill_to_last_hidden_states(\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        hidden_states = self.batch_forward(hidden_states, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_decode_to_last_hidden_states(\n        self, hidden_states: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward(hidden_states, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    
\"effect_mode\": \"none\",\n                },\n            },\n            \"fuse_embed_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill_to_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode_to_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill_to_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode_to_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n 
           },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/gemma/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gemma/gemma_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Gemma parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization\n\nfrom .gemma_model import GemmaConfig, GemmaForCausalLM\n\n\ndef huggingface(model_config: GemmaConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Create HF weight mapping for Gemma.\"\"\"\n    model = GemmaForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n    base_loader = make_standard_hf_loader(\n        model_cls=GemmaForCausalLM,\n    )\n    mapping = base_loader(model_config, quantization)\n\n    def add_one(name: str) -> None:\n        mlc_param = named_parameters[name]\n        mapping.add_mapping(\n            name,\n            [name],\n            functools.partial(\n                lambda x, dtype: (x + 1).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n    for i in range(model_config.num_hidden_layers):\n        add_one(f\"model.layers.{i}.input_layernorm.weight\")\n        add_one(f\"model.layers.{i}.post_attention_layernorm.weight\")\n\n    add_one(\"model.norm.weight\")\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/gemma/gemma_model.py",
    "content": "\"\"\"Implementation for Gemma architecture.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass GemmaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Gemma model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    attention_bias: bool\n    num_attention_heads: int\n    num_key_value_heads: int\n    head_dim: int\n    num_hidden_layers: int\n    rms_norm_eps: float\n    vocab_size: int\n    hidden_activation: Optional[str] = None\n    position_embedding_base: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.hidden_activation is None:\n            self.hidden_activation = self.kwargs.get(\"hidden_act\", None)\n        if self.hidden_activation not in (\"gelu\", \"gelu_pytorch_tanh\"):\n            raise ValueError(\"Only GeLU is supported as the activation for gemma.\")\n        if self.attention_bias:\n            raise ValueError('Only \"False\" attention_bias is supported for gemma')\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", 
\"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass GemmaEmbedding(nn.Embedding):\n    \"\"\"The embedding module specialized for Gemma so that\n    it can be shared with the final lm_head.\n    \"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, 
out_dtype=\"float32\")\n\n\nclass GemmaMLP(nn.Module):\n    def __init__(self, config: GemmaConfig):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.gelu(x1, approximate=\"tanh\") * x2)\n\n\nclass GemmaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GemmaConfig):\n        self.head_dim = config.head_dim\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert (\n            config.num_key_value_heads % config.tensor_parallel_shards == 0\n        ), f\"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards\"\n        assert (\n            config.num_key_value_heads >= config.tensor_parallel_shards\n        ), f\"Too large tensor_parallel_shards, must be smaller than {config.num_key_value_heads}\"\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            
in_features=self.num_q_heads * self.head_dim,\n            out_features=config.hidden_size,\n            bias=config.attention_bias,\n        )\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass GemmaDecoderLayer(nn.Module):\n    def __init__(self, config: GemmaConfig):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = GemmaAttention(config)\n        self.mlp = GemmaMLP(config)\n        # Gemma RMSNorm adds 1 to the weights. 
It is already fused in the loader\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass GemmaModel(nn.Module):\n    def __init__(self, config: GemmaConfig):\n        self.hidden_size = config.hidden_size\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = 
GemmaEmbedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [GemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        hidden_states = hidden_states * (self.hidden_size**0.5)\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass GemmaForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GemmaConfig):\n        self.model = GemmaModel(config)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def get_logits(self, hidden_states: Tensor):\n        logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is 
not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.get_logits(hidden_states)\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:-1,:]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n   
     self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n      
          },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/gemma2/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gemma2/gemma2_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Gemma2 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization\n\nfrom .gemma2_model import Gemma2Config, Gemma2ForCausalLM\n\n\ndef huggingface(model_config: Gemma2Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Create HF weight mapping for Gemma2.\"\"\"\n    model = Gemma2ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n    base_loader = make_standard_hf_loader(\n        model_cls=Gemma2ForCausalLM,\n    )\n    mapping = base_loader(model_config, quantization)\n\n    def add_one(name: str) -> None:\n        mlc_param = named_parameters[name]\n        mapping.add_mapping(\n            name,\n            [name],\n            functools.partial(\n                lambda x, dtype: (x + 1).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n    for i in range(model_config.num_hidden_layers):\n        add_one(f\"model.layers.{i}.input_layernorm.weight\")\n        add_one(f\"model.layers.{i}.post_attention_layernorm.weight\")\n        add_one(f\"model.layers.{i}.pre_feedforward_layernorm.weight\")\n        add_one(f\"model.layers.{i}.post_feedforward_layernorm.weight\")\n\n    add_one(\"model.norm.weight\")\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/gemma2/gemma2_model.py",
    "content": "\"\"\"Implementation for Gemma2 architecture.\"\"\"\n\nimport dataclasses\n\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm.model.gemma.gemma_model import (\n    GemmaAttention,\n    GemmaConfig,\n    GemmaForCausalLM,\n    GemmaMLP,\n    GemmaModel,\n)\nfrom mlc_llm.nn import PagedKVCache\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Gemma2Config(GemmaConfig):\n    \"\"\"Configuration of the Gemma2 model, in addition to the Gemma model\"\"\"\n\n    # NOTE: We ignore attn_logit_softcapping in the gemma2 implementation for now.\n    # The Gemma 2 team observed minor differences when soft-capping is removed during inference,\n    # according to https://huggingface.co/blog/gemma2.\n    # The soft-capping is also not supported by HuggingFace transformers `Gemma2SdpaAttention`.\n    attn_logit_softcapping: float = None\n    final_logit_softcapping: float = None\n    query_pre_attn_scalar: int = None\n    sliding_window: int = None\n\n    def __post_init__(self):\n        super().__post_init__()\n        # NOTE: override the context window size with the Gemma2 sliding window size,\n        # as the sliding window attention every other layer is yet to be supported.\n        self.context_window_size = self.sliding_window\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass Gemma2Attention(GemmaAttention):\n    def __init__(self, config: Gemma2Config):\n        super().__init__(config)\n        self.scaling_factor = (config.head_dim / config.query_pre_attn_scalar) ** 0.5\n\n\nclass Gemma2DecoderLayer(nn.Module):\n    def __init__(self, config: Gemma2Config):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = Gemma2Attention(config)\n        self.mlp = GemmaMLP(config)\n        # Gemma RMSNorm adds 1 to the weights. 
It is already fused in the loader\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.pre_feedforward_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n        self.post_feedforward_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        out = self._apply_post_matmul_norm(out, norm=self.post_attention_layernorm)\n        hidden_states = out + hidden_states\n\n        out = self.pre_feedforward_layernorm(hidden_states)\n        out = self.mlp(out)\n        out = self._apply_post_matmul_norm(out, norm=self.post_feedforward_layernorm)\n        hidden_states = out + hidden_states\n\n        return hidden_states\n\n    def 
_apply_post_matmul_norm(self, out: Tensor, norm: nn.Module):\n        if self.tensor_parallel_shards > 1:\n            return norm(op.ccl_allreduce(out, \"sum\"))\n        return norm(out)\n\n\nclass Gemma2Model(GemmaModel):\n    def __init__(self, config: Gemma2Config):\n        super().__init__(config)\n        self.layers = nn.ModuleList(\n            [Gemma2DecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n\n\nclass Gemma2ForCausalLM(GemmaForCausalLM):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Gemma2Config):\n        super().__init__(config)\n        self.model = Gemma2Model(config)\n        self.final_logit_softcapping = config.final_logit_softcapping\n\n    def get_logits(self, hidden_states: Tensor):\n        logits = super().get_logits(hidden_states)\n        if self.final_logit_softcapping is not None:\n            logits = op.tanh(logits / self.final_logit_softcapping) * self.final_logit_softcapping\n        return logits\n"
  },
  {
    "path": "python/mlc_llm/model/gemma3/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gemma3/gemma3_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Gemma3 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization\n\nfrom .gemma3_model import Gemma3Config, Gemma3ForCausalLM\n\n\ndef huggingface(model_config: Gemma3Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Create HF weight mapping for Gemma3.\"\"\"\n    model = Gemma3ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n    mlc_prefix = \"language_model.\"\n    if model_config.is_text_model:\n        hf_prefix = \"\"\n    else:\n        hf_prefix = \"language_model.\"\n\n    def name_transform(name: str) -> str:\n        if name.startswith(mlc_prefix):\n            name = name[len(mlc_prefix) :]\n        return f\"{hf_prefix}{name}\"\n\n    def num_layers(config: object) -> int:\n        return config.text_config.num_hidden_layers  # type: ignore[attr-defined]\n\n    base_loader = make_standard_hf_loader(\n        model_cls=Gemma3ForCausalLM,\n        include_qkv=False,\n        include_gate_up=True,\n        gate_up_target_name=\"gate_up_proj\",\n        num_layers_getter=num_layers,\n        layer_prefix=f\"{mlc_prefix}model.layers\",\n        name_transform=name_transform,\n    )\n    mapping = base_loader(model_config, quantization)\n\n    def add_one(name: str) -> None:\n        mlc_param = named_parameters[mlc_prefix + name]\n        mapping.add_mapping(\n            mlc_prefix + name,\n            [name_transform(mlc_prefix + name)],\n            functools.partial(\n                lambda x, dtype: (x + 1).astype(dtype),\n       
         dtype=mlc_param.dtype,\n            ),\n        )\n\n    for i in range(model_config.text_config.num_hidden_layers):\n        add_one(f\"model.layers.{i}.input_layernorm.weight\")\n        add_one(f\"model.layers.{i}.post_attention_layernorm.weight\")\n        add_one(f\"model.layers.{i}.pre_feedforward_layernorm.weight\")\n        add_one(f\"model.layers.{i}.post_feedforward_layernorm.weight\")\n        add_one(f\"model.layers.{i}.self_attn.k_norm.weight\")\n        add_one(f\"model.layers.{i}.self_attn.q_norm.weight\")\n\n    add_one(\"model.norm.weight\")\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/gemma3/gemma3_model.py",
    "content": "\"\"\"Implementation for Gemma3 architecture.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.gemma.gemma_model import GemmaEmbedding\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Gemma3TextConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the text model inside Gemma3\"\"\"\n\n    # NOTE More fields have defaults due to Huggingface Gemma3 configs missing fields\n    # The defaults for these fields can be found in the transformers library\n    hidden_size: int\n    intermediate_size: int\n    num_hidden_layers: int\n    attention_bias: bool = False\n    num_attention_heads: int = 8\n    num_key_value_heads: int = 4\n    head_dim: int = 256\n    rms_norm_eps: float = 1e-6\n    hidden_activation: Optional[str] = \"gelu_pytorch_tanh\"\n    position_embedding_base: int = 1_000_000\n    rope_scaling: int = 0\n    context_window_size: int = 131_072\n    prefill_chunk_size: int = 0\n\n    query_pre_attn_scalar: int = 256\n    sliding_window_size: int = None\n    sliding_window_pattern = 6\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.hidden_activation is None:\n            self.hidden_activation = self.kwargs.get(\"hidden_act\", None)\n        if self.sliding_window_size is None:\n            self.sliding_window_size = self.kwargs.get(\"sliding_window\", None)\n        if self.hidden_activation not in (\"gelu\", \"gelu_pytorch_tanh\"):\n            raise ValueError(\"Only GeLU is supported as the activation for 
gemma.\")\n        if self.attention_bias:\n            raise ValueError('Only \"False\" attention_bias is supported for gemma')\n        if self.position_embedding_base == 1000000 and \"rope_theta\" in self.kwargs:\n            self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        # NOTE: override the context window size with the Gemma2 sliding window size,\n        
# as the sliding window attention every other layer is yet to be supported.\n        self.context_window_size = max(self.sliding_window_size, 8192)\n\n\n@dataclasses.dataclass\nclass Gemma3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Gemma3 model\"\"\"\n\n    text_config: Gemma3TextConfig = None\n    vocab_size: int = 262_208\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    context_window_size: int = -1\n    sliding_window_size: int = -1\n    prefill_chunk_size: int = -1\n    is_text_model: bool = False\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.text_config is None:\n            self.is_text_model = True\n            self.text_config = Gemma3TextConfig.from_dict(self.kwargs)\n\n        text_config_dict: Dict[str, Any]  # type: ignore\n        if isinstance(self.text_config, Gemma3TextConfig):\n            text_config_dict = dataclasses.asdict(self.text_config)\n        else:\n            text_config_dict = dict(self.text_config)\n\n        for k, v in text_config_dict.pop(\"kwargs\", {}).items():\n            text_config_dict[k] = v\n\n        self.text_config = Gemma3TextConfig.from_dict(text_config_dict)\n\n        for k in [\"context_window_size\", \"prefill_chunk_size\", \"sliding_window_size\"]:\n            if getattr(self, k) <= 0:\n                if hasattr(self.text_config, k):\n                    setattr(self, k, getattr(self.text_config, k))\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass Gemma3MLP(nn.Module):\n    def __init__(self, config: Gemma3Config):\n        super().__init__()\n        if config.text_config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.text_config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n       
 self.intermediate_size = (\n            config.text_config.intermediate_size // config.tensor_parallel_shards\n        )\n        self.gate_up_proj = nn.Linear(\n            in_features=config.text_config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(\n            self.intermediate_size, config.text_config.hidden_size, bias=False\n        )\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.gelu(x1, approximate=\"tanh\") * x2)\n\n\nclass Gemma3Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Gemma3Config):\n        self.head_dim = config.text_config.head_dim\n        self.num_q_heads = config.text_config.num_attention_heads // config.tensor_parallel_shards\n        self.num_kv_heads = config.text_config.num_key_value_heads\n        assert (\n            self.num_kv_heads % config.tensor_parallel_shards == 0\n        ), f\"num_kv_heads({self.num_kv_heads}) must be divisible by tensor_parallel_shards\"\n        assert (\n            self.num_kv_heads >= config.tensor_parallel_shards\n        ), f\"Too large tensor_parallel_shards, must be smaller than {self.num_kv_heads}\"\n        self.num_kv_heads = self.num_kv_heads // config.tensor_parallel_shards\n        self.q_proj = nn.Linear(\n            in_features=config.text_config.hidden_size,\n            out_features=self.num_q_heads * self.head_dim,\n            bias=config.text_config.attention_bias,\n        )\n        self.k_proj = nn.Linear(\n            in_features=config.text_config.hidden_size,\n            out_features=self.num_kv_heads * self.head_dim,\n            bias=config.text_config.attention_bias,\n        )\n        self.v_proj = nn.Linear(\n            in_features=config.text_config.hidden_size,\n            out_features=self.num_kv_heads 
* self.head_dim,\n            bias=config.text_config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            in_features=self.num_q_heads * self.head_dim,\n            out_features=config.text_config.hidden_size,\n            bias=config.text_config.attention_bias,\n        )\n        self.q_norm = nn.RMSNorm(\n            config.text_config.head_dim, -1, config.text_config.rms_norm_eps, bias=False\n        )\n        self.k_norm = nn.RMSNorm(\n            config.text_config.head_dim, -1, config.text_config.rms_norm_eps, bias=False\n        )\n        # self.scaling_factor = (self.head_dim / config.text_config.query_pre_attn_scalar) ** 0.5\n        self.scaling = config.text_config.query_pre_attn_scalar**-0.5\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q = self.head_dim, self.num_q_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        q_proj = op.reshape(self.q_proj(hidden_states), (b, s, -1, d))\n        k_proj = op.reshape(self.k_proj(hidden_states), (b, s, -1, d))\n        v_proj = op.reshape(self.v_proj(hidden_states), (b, s, -1, d))\n\n        q_norm = self.q_norm(q_proj)\n        k_norm = self.k_norm(k_proj)\n\n        qkv = op.concat([q_norm, k_norm, v_proj], dim=2)\n\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.scaling\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass Gemma3DecoderLayer(nn.Module):\n    def __init__(self, config: Gemma3Config):\n        rms_norm_eps = config.text_config.rms_norm_eps\n        self.self_attn = Gemma3Attention(config)\n        self.mlp = Gemma3MLP(config)\n        # Gemma RMSNorm adds 1 to the weights. 
It is already fused in the loader\n        self.input_layernorm = nn.RMSNorm(\n            config.text_config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.text_config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n        self.pre_feedforward_layernorm = nn.RMSNorm(\n            config.text_config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n        self.post_feedforward_layernorm = nn.RMSNorm(\n            config.text_config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            i = self.mlp.intermediate_size\n            _set(self.self_attn.q_proj, tp.ShardSingleDim(\"_shard_q\", dim=0))\n            _set(self.self_attn.k_proj, tp.ShardSingleDim(\"_shard_k\", dim=0))\n            _set(self.self_attn.v_proj, tp.ShardSingleDim(\"_shard_v\", dim=0))\n            _set(self.self_attn.q_norm, tp.ShardSingleDim(\"_shard_q_norm\", dim=0))\n            _set(self.self_attn.k_norm, tp.ShardSingleDim(\"_shard_k_norm\", dim=0))\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        out = self._apply_post_matmul_norm(out, norm=self.post_attention_layernorm)\n        hidden_states = out + hidden_states\n\n        out = self.pre_feedforward_layernorm(hidden_states)\n        out = self.mlp(out)\n        out = 
self._apply_post_matmul_norm(out, norm=self.post_feedforward_layernorm)\n        hidden_states = out + hidden_states\n\n        return hidden_states\n\n    def _apply_post_matmul_norm(self, out: Tensor, norm: nn.Module):\n        if self.tensor_parallel_shards > 1:\n            return norm(op.ccl_allreduce(out, \"sum\"))\n        return norm(out)\n\n\nclass Gemma3TextModel(nn.Module):\n    def __init__(self, config: Gemma3Config):\n        self.hidden_size = config.text_config.hidden_size\n        assert config.text_config.hidden_size % config.text_config.num_attention_heads == 0\n        self.embed_tokens = GemmaEmbedding(\"vocab_size\", config.text_config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Gemma3DecoderLayer(config) for _ in range(config.text_config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(\n            config.text_config.hidden_size,\n            -1,\n            config.text_config.rms_norm_eps,\n            bias=False,\n        )\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        hidden_states = hidden_states * (self.hidden_size**0.5)\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Gemma3LanguageModel(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Gemma3Config):\n        self.model = Gemma3TextModel(config)\n        self.config = config\n        self.num_hidden_layers = config.text_config.num_hidden_layers\n        self.num_attention_heads = config.text_config.num_attention_heads\n        self.num_key_value_heads = config.text_config.num_key_value_heads\n        self.head_dim = config.text_config.head_dim\n        self.hidden_size = config.text_config.hidden_size\n        self.vocab_size = config.vocab_size\n        
self.rope_theta = config.text_config.position_embedding_base\n        self.rope_scaling = config.text_config.rope_scaling\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def get_logits(self, hidden_states: Tensor):\n        logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.get_logits(hidden_states)\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :] (keep only the last position)\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = 
self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        # if \"factor\" in self.rope_scaling:\n        #     rope_scaling = self.rope_scaling[\"factor\"]\n        # else:\n        #     rope_scaling = 1\n        return PagedKVCache.create_generic(\n            attn_kind=[\n                (\n                    \"mha_sliding\"\n                    if ((i + 1) % self.config.text_config.sliding_window_pattern)\n                    else \"mha\"\n                )\n                for i in range(self.num_hidden_layers)\n            ],\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // 
self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                
\"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n\n\nclass Gemma3ForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Gemma3Config):\n        super().__init__()\n        self.config = config\n        self.language_model = Gemma3LanguageModel(config)\n        self.vocab_size = config.vocab_size\n        self.dtype = \"float32\"\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        self.language_model.to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def get_logits(self, hidden_states: Tensor):\n        logits = self.language_model.model.embed_tokens.lm_head_forward(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n 
       paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.language_model.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.get_logits(hidden_states)\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.language_model.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :] (keep only the last position)\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.language_model.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.language_model.model(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return 
logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        # if \"factor\" in self.language_model.rope_scaling:\n        #     rope_scaling = self.language_model.rope_scaling[\"factor\"]\n        # else:\n        #     rope_scaling = 1\n        return PagedKVCache.create_generic(\n            attn_kind=[\n                (\n                    \"mha_sliding\"\n                    if ((i + 1) % self.config.text_config.sliding_window_pattern)\n                    else \"mha\"\n                )\n                for i in range(self.language_model.num_hidden_layers)\n            ],\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.language_model.num_hidden_layers,\n            num_attention_heads=self.language_model.num_attention_heads\n            // self.tensor_parallel_shards,\n            num_key_value_heads=self.language_model.num_key_value_heads\n            // self.tensor_parallel_shards,\n            qk_head_dim=self.language_model.head_dim,\n            v_head_dim=self.language_model.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.language_model.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": 
nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor(\n                    [1, \"seq_len\", self.language_model.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.language_model.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor(\n                    [1, \"seq_len\", self.language_model.hidden_size], self.dtype\n                ),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor(\n                    [\"batch_size\", 1, self.language_model.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": 
nn.spec.Tensor(\n                    [1, \"seq_len\", self.language_model.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/gpt2/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gpt2/gpt2_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's GPT-2 parameters map from other formats, for example HuggingFace\nPyTorch and HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .gpt2_model import GPT2Config, GPT2LMHeadModel\n\n\ndef huggingface(model_config: GPT2Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : GPT2Config\n        The configuration of the GPT-2 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = GPT2LMHeadModel(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    mapping.add_mapping(\n        \"lm_head.weight\",\n        [\"wte.weight\"],\n        functools.partial(\n            lambda x, dtype: x.astype(dtype),\n            dtype=named_parameters[\"transformer.wte.weight\"].dtype,\n        ),\n    )\n\n    for i in range(model_config.n_layer):\n        mapping.add_unused(f\"h.{i}.attn.bias\")\n\n        # Transpose c_attn, c_proj and c_fc weights since GPT-2 uses Conv1D\n        for conv1d_weight_name in [\n            \"attn.c_attn\",\n            \"attn.c_proj\",\n            \"mlp.c_proj\",\n            \"mlp.c_fc\",\n        ]:\n            src_name = f\"h.{i}.{conv1d_weight_name}.weight\"\n            mlc_name = f\"transformer.{src_name}\"\n            mapping.add_mapping(\n                mlc_name,\n            
    [src_name],\n                functools.partial(\n                    lambda x, dtype: x.transpose().astype(dtype),\n                    dtype=named_parameters[mlc_name].dtype,\n                ),\n            )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            # transformer.h.0.attn.c_attn.weight --> h.0.attn.c_attn.weight\n            source_name = mlc_name.split(\".\", 1)[1]\n            mapping.add_mapping(\n                mlc_name,\n                [source_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/gpt2/gpt2_model.py",
    "content": "\"\"\"\nImplementation for GPT-2 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass GPT2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the GPT-2 model.\"\"\"\n\n    vocab_size: int\n    n_embd: int\n    n_layer: int\n    n_head: int\n    layer_norm_epsilon: float\n    n_inner: int = -1\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    scale_attn_by_inverse_layer_idx: bool = False\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.n_inner is None or self.n_inner == -1:\n            self.n_inner = 4 * self.n_embd\n        if self.context_window_size == 0:\n            for name in [\"n_positions\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `n_positions` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.n_embd // self.n_head\n        assert self.head_dim * self.n_head == self.n_embd\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass GPT2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPT2Config):\n        self.embed_dim = config.n_embd\n        if config.n_head % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.n_head} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.n_head // config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.scale_attn_by_inverse_layer_idx = 
config.scale_attn_by_inverse_layer_idx\n\n        self.c_attn = nn.Linear(\n            in_features=self.embed_dim,\n            out_features=3 * self.num_heads * self.head_dim,\n            bias=True,\n        )\n        self.c_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape\n\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, 3 * h, d))\n\n        if self.scale_attn_by_inverse_layer_idx:\n            attn_score_scaling_factor = 1.0 / float(layer_id + 1)\n        else:\n            attn_score_scaling_factor = 1.0\n\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id,\n                qkv,\n                self.num_heads,\n                sm_scale=attn_score_scaling_factor * (self.head_dim**-0.5),\n            ),\n            (b, s, h * d),\n        )\n        return self.c_proj(output)\n\n\nclass GPT2MLP(nn.Module):\n    def __init__(self, config: GPT2Config):\n        embed_dim = config.n_embd\n        if config.n_inner % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.n_inner} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        intermediate_size = config.n_inner // config.tensor_parallel_shards\n        self.c_fc = nn.Linear(embed_dim, intermediate_size)\n        self.c_proj = nn.Linear(intermediate_size, embed_dim)\n\n    def forward(self, hidden_states: Tensor):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = op.gelu(hidden_states, approximate=\"tanh\")\n        hidden_states = self.c_proj(hidden_states)\n        return hidden_states\n\n\nclass GPT2Block(nn.Module):\n    def __init__(self, config: 
GPT2Config):\n        hidden_size = config.n_embd\n        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.attn = GPT2Attention(config)\n        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)\n        self.mlp = GPT2MLP(config)\n\n        def _set_tp():\n            def _set(param, hint):\n                param.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = k = v = self.attn.num_heads * hd\n            _set(\n                self.attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(\n                self.attn.c_attn.bias,\n                tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.attn.c_proj.weight, tp.ShardSingleDim(\"_shard_attn_c_proj\", dim=1))\n            _set(\n                self.mlp.c_fc.weight,\n                tp.ShardSingleDim(\"_shard_c_fc_weight\", dim=0),\n            )\n            _set(self.mlp.c_fc.bias, tp.ShardSingleDim(\"_shard_c_fc_bias\", dim=0))\n            _set(self.mlp.c_proj.weight, tp.ShardSingleDim(\"_shard_mlp_c_proj\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        with (\n            tp.shard_bias(self.attn.c_proj, self.tensor_parallel_shards),\n            tp.shard_bias(self.mlp.c_proj, self.tensor_parallel_shards),\n        ):\n            hidden_states = self._apply_residual(\n                self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id),\n                hidden_states,\n            )\n            hidden_states = self._apply_residual(self.mlp(self.ln_2(hidden_states)), hidden_states)\n\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return 
op.ccl_allreduce(out + residual / self.tensor_parallel_shards, \"sum\")\n        return out + residual\n\n\nclass GPT2Model(nn.Module):\n    def __init__(self, config: GPT2Config):\n        assert config.n_embd % config.n_head == 0\n        self.wte = nn.Embedding(\"vocab_size\", config.n_embd)\n        self.wpe = nn.Embedding(config.context_window_size, config.n_embd)\n        self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])\n        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        # Position Embeddings\n        # Generate np.arange(offset, offset+seq_len)\n        # shape[1] indicates the total query length in the batch\n        input_positions = paged_kv_cache.get_query_positions(inputs.shape[1])\n        pos_embd = self.wpe(input_positions)\n\n        # Pass through GPT2Block\n        hidden_states = inputs + pos_embd\n        for layer_id, layer in enumerate(self.h):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.ln_f(hidden_states)\n        return hidden_states\n\n\nclass GPT2LMHeadModel(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPT2Config):\n        self.transformer = GPT2Model(config)\n        self.lm_head = nn.Linear(config.n_embd, \"vocab_size\", bias=False)\n        self.n_layer = config.n_layer\n        self.n_embed = config.n_embd\n        self.n_head = config.n_head\n        self.head_dim = config.head_dim\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    
):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.transformer.wte(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: 
Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.n_layer,\n            num_attention_heads=self.n_head // self.tensor_parallel_shards,\n            num_key_value_heads=self.n_head // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NONE,\n            rope_scale=-1,\n            rope_theta=-1,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.n_embed], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n   
             },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.n_embed], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.n_embed], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.n_embed], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.n_embed], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n   
     return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/gpt_bigcode/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gpt_bigcode/gpt_bigcode_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's GPTBigCode parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .gpt_bigcode_model import GPTBigCodeForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=GPTBigCodeForCausalLM,\n    include_qkv=False,\n    include_gate_up=False,\n)\n"
  },
  {
    "path": "python/mlc_llm/model/gpt_bigcode/gpt_bigcode_model.py",
    "content": "\"\"\"\nImplementation for GPTBigCode architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass GPTBigCodeConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the GPTBigCode model.\"\"\"\n\n    n_embd: int\n    n_inner: int\n    n_head: int\n    n_layer: int\n    n_positions: int\n    layer_norm_epsilon: float\n    vocab_size: int\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            if self.n_positions > 0:\n                self.context_window_size = self.n_positions\n                logger.info(\n                    \"%s not found in config.json. 
Falling back to %s (%d)\",\n                    bold(\"context_window_size\"),\n                    bold(\"n_positions\"),\n                    self.context_window_size,\n                )\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because neither \"\n                    \"`context_window_size` nor a positive `n_positions` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass GPTBigCodeMLP(nn.Module):\n    def __init__(self, config: GPTBigCodeConfig):\n        super().__init__()\n        self.n_inner = config.n_inner // config.tensor_parallel_shards\n        self.c_fc = nn.Linear(in_features=config.n_embd, out_features=self.n_inner, bias=True)\n        self.c_proj = nn.Linear(in_features=self.n_inner, out_features=config.n_embd, bias=True)\n\n    def forward(self, x: Tensor):\n        hidden_states = self.c_fc(x)\n        hidden_states = op.gelu(hidden_states)\n        hidden_states = self.c_proj(hidden_states)\n        return hidden_states\n\n\nclass GPTBigCodeAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPTBigCodeConfig):\n        self.n_embd = 
config.n_embd\n        self.head_dim = config.n_embd // config.n_head\n        self.num_q_heads = config.n_head // config.tensor_parallel_shards\n        self.num_kv_heads = 1\n        assert (\n            config.tensor_parallel_shards == 1\n        ), \"GPTBigCode only supports tensor_parallel_shards = 1\"\n        self.c_attn = nn.Linear(\n            in_features=self.n_embd,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=True,\n        )\n        self.c_proj = nn.Linear(\n            in_features=self.num_q_heads * self.head_dim,\n            out_features=config.n_embd,\n            bias=True,\n        )\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n    ):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n\n        # QKV Projection\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, h_q, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.c_proj(output)\n\n\nclass GPTBigCodeBlock(nn.Module):\n    def __init__(self, config: GPTBigCodeConfig):\n        self.attn = GPTBigCodeAttention(config)\n        self.mlp = GPTBigCodeMLP(config)\n        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.n_embd // config.n_head\n            q = config.n_head * hd\n            k = 1 * hd\n            v = 1 * hd\n            _set(\n                self.attn.c_attn,\n                
tp.ShardSingleDim(\"_shard_c_attn\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.attn.c_proj, tp.ShardSingleDim(\"_shard_c_proj\", dim=1))\n            _set(self.mlp.c_fc, tp.ShardSingleDim(\"_shard_mlp_c_fc\", dim=0))\n            _set(self.mlp.c_proj, tp.ShardSingleDim(\"_shard_mlp_c_proj\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = out + hidden_states\n        out = self.mlp(self.ln_2(hidden_states))\n        hidden_states = out + hidden_states\n        return hidden_states\n\n\nclass GPTBigCodeModel(nn.Module):\n    def __init__(self, config: GPTBigCodeConfig):\n        assert config.n_embd % config.n_head == 0\n        self.wte = nn.Embedding(\"vocab_size\", config.n_embd)\n        self.wpe = nn.Embedding(config.n_positions, config.n_embd)\n        self.h = nn.ModuleList([GPTBigCodeBlock(config) for _ in range(config.n_layer)])\n        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        # Position Embeddings\n        # shape[1] indicates the total query length in the batch\n        input_positions = paged_kv_cache.get_query_positions(input_embed.shape[1])\n        pos_embd = self.wpe(input_positions)\n\n        # apply position embeddings\n        hidden_states = input_embed + pos_embd\n        for layer_id, layer in enumerate(self.h):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.ln_f(hidden_states)\n\n        return hidden_states\n\n\nclass GPTBigCodeForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPTBigCodeConfig):\n        self.transformer = GPTBigCodeModel(config)\n        
self.lm_head = nn.Linear(config.n_embd, \"vocab_size\", bias=False)\n        self.n_layer = config.n_layer\n        self.n_embd = config.n_embd\n        self.num_q_heads = config.n_head // config.tensor_parallel_shards\n        self.num_kv_heads = 1\n        self.head_dim = config.n_embd // config.n_head\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embed: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.transformer.wte(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n  
      op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.n_layer,\n            # num_q_heads is already divided by tensor_parallel_shards in __init__\n            num_attention_heads=self.num_q_heads,\n            num_key_value_heads=self.num_kv_heads,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NONE,\n            
rope_scale=-1,\n            rope_theta=-1,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.n_embd], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.n_embd], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.n_embd], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.n_embd], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": 
nn.spec.Tensor([1, \"seq_len\", self.n_embd], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/gpt_j/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gpt_j/gpt_j_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's GPTJ parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .gpt_j_model import GPTJForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=GPTJForCausalLM,\n    layer_prefix=\"transformer.h\",\n    qkv_target_name=\"c_attn\",\n    include_gate_up=False,\n    num_layers_getter=lambda config: config.n_layer,  # type: ignore[attr-defined]\n)\n"
  },
  {
    "path": "python/mlc_llm/model/gpt_j/gpt_j_model.py",
    "content": "\"\"\"\nImplementation for GPTJ architecture.\nTODO: add docstring\n\"\"\"\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass GPTJConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the GPTJ model.\"\"\"\n\n    vocab_size: int\n    n_embd: int\n    n_layer: int\n    n_head: int\n    layer_norm_epsilon: float\n    rotary_dim: int\n    activation_function: str\n    n_inner: Optional[int] = None  # None means 4 * n_embd\n    rope_scaling: Optional[Dict[str, Any]] = None\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    head_dim: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"n_positions\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `n_positions` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.n_embd // self.n_head\n        assert self.head_dim * self.n_head == self.n_embd\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass GPTJAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPTJConfig):\n        self.embed_dim = config.n_embd\n        self.num_heads = config.n_head // config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.max_position_embeddings = config.context_window_size\n        self.rope_theta = 10000\n        self.rotary_dim = config.rotary_dim\n        self.c_attn = nn.Linear(\n            in_features=self.embed_dim,\n            out_features=3 * self.embed_dim,\n            bias=False,\n        
)\n        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)\n\n    def forward(  # pylint: disable=too-many-locals\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n    ):\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape\n\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, 3 * h, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h * d),\n        )\n        return self.out_proj(output)\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass GPTJMLP(nn.Module):\n    def __init__(self, config: GPTJConfig):  # in MLP: intermediate_size= 4 * embed_dim\n        embed_dim = config.n_embd\n        inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner\n        self.fc_in = nn.Linear(embed_dim, inner_dim, bias=True)\n        self.fc_out = nn.Linear(inner_dim, embed_dim, bias=True)\n        self.act_fn = ACT2FN[config.activation_function]\n\n    def forward(self, hidden_states: Tensor):\n        hidden_states = self.fc_in(hidden_states)\n        hidden_states = self.act_fn(hidden_states)\n        hidden_states = self.fc_out(hidden_states)\n        return hidden_states\n\n\nclass GPTJBlock(nn.Module):\n    def __init__(self, config: GPTJConfig):\n        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n        self.attn = GPTJAttention(config)\n        self.mlp = GPTJMLP(config)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.attn.num_heads * hd\n            k = 
self.attn.num_heads * hd\n            v = self.attn.num_heads * hd\n            _set(\n                self.attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.attn.out_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.fc_in.weight,\n                tp.ShardSingleDim(\"_shard_c_fc_weight\", dim=0),\n            )\n            _set(self.mlp.fc_out.weight, tp.ShardSingleDim(\"_shard_mlp_c_proj\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        residual = hidden_states\n        hidden_states = self.ln_1(hidden_states)\n        attn_output = self.attn(hidden_states, paged_kv_cache, layer_id)\n        feed_forward_hidden_states = self.mlp(hidden_states)\n        hidden_states = self._apply_residual(attn_output + feed_forward_hidden_states, residual)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass GPTJModel(nn.Module):\n    def __init__(self, config: GPTJConfig):\n        self.embed_dim = config.n_embd\n        self.vocab_size = config.vocab_size\n        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)\n        self.h = nn.ModuleList([GPTJBlock(config) for _ in range(config.n_layer)])\n        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.h):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.ln_f(hidden_states)\n        return hidden_states\n\n\nclass 
GPTJForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPTJConfig):\n        self.transformer = GPTJModel(config)\n        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, dtype=\"float32\")\n        self.dtype = \"float32\"\n        self.hidden_size = config.n_embd\n        self.num_hidden_layers = config.n_layer\n        self.intermediate_size = 4 * config.n_embd if config.n_inner is None else config.n_inner\n        self.num_attention_heads = config.n_head\n        self.rope_theta = 10000\n        self.rope_scaling = config.rope_scaling\n        self.vocab_size = config.vocab_size\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.n_embd // config.n_head\n        self.rotary_dim = config.rotary_dim\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.transformer.wte(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n 
       hidden_states = self.transformer(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            
page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rotary_dim=self.rotary_dim,\n            rope_scaling=self.rope_scaling,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n              
      \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/gpt_neox/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/gpt_neox/gpt_neox_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's GPTNeoX parameters map from other formats, for example HuggingFace\nPyTorch and HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .gpt_neox_model import GPTNeoXConfig, GPTNeoXForCausalLM\n\n\ndef huggingface(model_config: GPTNeoXConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : GPTNeoXConfig\n        The configuration of the GPTNeoX model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = GPTNeoXForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # inv_freq/masked_bias/bias are not used in the model\n        attn = f\"gpt_neox.layers.{i}.attention\"\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n        mapping.add_unused(f\"{attn}.masked_bias\")\n        mapping.add_unused(f\"{attn}.bias\")\n\n        # change the layout of query_key_value\n        def transform_qkv_layout(w, dtype):  # pylint: disable=invalid-name\n            num_attention_heads = model_config.num_attention_heads\n            head_dim = model_config.head_dim\n\n            org_shape = w.shape\n            w = np.reshape(w, [num_attention_heads, 3 * head_dim, -1])\n            qkv = np.split(w, indices_or_sections=3, 
axis=1)\n            w = np.concatenate(qkv, axis=0)\n            w = np.reshape(w, org_shape)\n            return w.astype(dtype)\n\n        qkv_proj = f\"{attn}.query_key_value\"\n        for param_name in [\"weight\", \"bias\"]:\n            mlc_name = f\"{qkv_proj}.{param_name}\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    transform_qkv_layout,\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            if \".dense_h_to_4h.bias\" in mlc_name or \".dense_4h_to_h.bias\" in mlc_name:\n                param_dtype = model_config.ffn_out_dtype\n            else:\n                param_dtype = mlc_param.dtype\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=param_dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/gpt_neox/gpt_neox_model.py",
    "content": "\"\"\"\nImplementation for GPTNeoX architecture.\n\"\"\"\n\nimport dataclasses\nimport logging\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass GPTNeoXConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the GPTNeoX model.\"\"\"\n\n    use_parallel_residual: bool\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    layer_norm_eps: float\n    vocab_size: int\n    rotary_pct: float\n    position_embedding_base: int = 0\n    context_window_size: int = 0\n    head_dim: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    ffn_out_dtype: str = \"float32\"\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass GPTNeoXAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Multi-headed attention from 'Attention Is All You Need' paper\"\"\"\n\n    def __init__(self, config: GPTNeoXConfig):\n        self.rope_theta = config.position_embedding_base\n        
self.hidden_size = config.hidden_size\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.query_key_value = nn.Linear(\n            in_features=self.hidden_size,\n            out_features=3 * self.num_attention_heads * self.head_dim,\n            bias=True,\n        )\n        self.dense = nn.Linear(\n            self.num_attention_heads * self.head_dim, self.hidden_size, bias=True\n        )\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        # hidden_states: [batch_size, seq_len, hidden_size]\n        batch_size, seq_len, _ = hidden_states.shape\n\n        # q/k/v states: [batch_size, seq_len, hidden_size]\n        qkv = self.query_key_value(hidden_states)\n        qkv = op.reshape(qkv, (batch_size, seq_len, 3 * self.num_attention_heads, self.head_dim))\n\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_attention_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (batch_size, seq_len, self.head_dim * self.num_attention_heads),\n        )\n        attn_output = self.dense(output)\n        return attn_output\n\n\nclass GPTNeoXMLP(nn.Module):\n    def __init__(self, config: GPTNeoXConfig):\n        super().__init__()\n        out_dtype = config.ffn_out_dtype\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        
self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.dense_h_to_4h = nn.Linear(\n            config.hidden_size,\n            self.intermediate_size,\n            out_dtype=out_dtype,\n        )\n        self.dense_4h_to_h = nn.Linear(\n            self.intermediate_size,\n            config.hidden_size,\n            out_dtype=out_dtype,\n        )\n\n    def forward(self, hidden_states: Tensor):\n        dtype = hidden_states.dtype\n        if hidden_states.dtype != dtype:\n            hidden_states = hidden_states.astype(dtype)\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = op.gelu(hidden_states)\n        if hidden_states.dtype != dtype:\n            hidden_states = hidden_states.astype(dtype)\n        hidden_states = self.dense_4h_to_h(hidden_states)\n        if hidden_states.dtype != dtype:\n            hidden_states = hidden_states.astype(dtype)\n        return hidden_states\n\n\nclass GPTNeoXLayer(nn.Module):\n    def __init__(self, config: GPTNeoXConfig):\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n        self.attention = GPTNeoXAttention(config)\n        self.mlp = GPTNeoXMLP(config)\n        self.use_parallel_residual = config.use_parallel_residual\n\n        def _set_tp():\n            def _set(param, hint):\n                param.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = k = v = self.attention.num_attention_heads * hd\n            _set(\n                self.attention.query_key_value.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(\n                self.attention.query_key_value.bias,\n                tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n            )\n            
_set(self.attention.dense.weight, tp.ShardSingleDim(\"_shard_dense\", dim=1))\n            _set(\n                self.mlp.dense_h_to_4h.weight,\n                tp.ShardSingleDim(\"_shard_dense_h_to_4h_weight\", dim=0),\n            )\n            _set(\n                self.mlp.dense_h_to_4h.bias,\n                tp.ShardSingleDim(\"_shard_dense_h_to_4h_bias\", dim=0),\n            )\n            _set(\n                self.mlp.dense_4h_to_h.weight,\n                tp.ShardSingleDim(\"_shard_dense_4h_to_h\", dim=1),\n            )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        dtype = hidden_states.dtype\n        attn_input = self.input_layernorm(hidden_states)\n        with tp.shard_bias(self.attention.dense, self.tensor_parallel_shards):\n            attn_output = self.attention(\n                attn_input,\n                paged_kv_cache,\n                layer_id,\n            )\n        if self.use_parallel_residual:\n            mlp_input = self.post_attention_layernorm(hidden_states)\n            mlp_output = self.mlp(mlp_input)\n            hidden_states = mlp_output + attn_output + hidden_states\n        else:\n            attn_output = self._apply_residual(attn_output, hidden_states)\n            mlp_input = self.post_attention_layernorm(attn_output)\n            with tp.shard_bias(self.mlp.dense_4h_to_h, self.tensor_parallel_shards):\n                mlp_output = self.mlp(mlp_input)\n            hidden_states = self._apply_residual(mlp_output.astype(dtype), attn_output)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out + residual / self.tensor_parallel_shards, \"sum\")\n        return out + residual\n\n\nclass GPTNeoXModel(nn.Module):\n    def __init__(self, config: GPTNeoXConfig):\n        
self.embed_in = nn.Embedding(num=\"vocab_size\", dim=config.hidden_size)\n        self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])\n        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.final_layer_norm(hidden_states)\n        return hidden_states\n\n\nclass GPTNeoXForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: GPTNeoXConfig):\n        self.gpt_neox = GPTNeoXModel(config)\n        self.embed_out = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=\"vocab_size\",\n            bias=False,\n            dtype=\"float32\",\n        )\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = config.head_dim\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n        self.rotary_pct = config.rotary_pct\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.gpt_neox(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = 
self.embed_out(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.gpt_neox.embed_in(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.gpt_neox(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.embed_out(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.gpt_neox(input_embed, paged_kv_cache)\n        logits = self.embed_out(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        
logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n            rotary_dim=int(self.head_dim * self.rotary_pct),\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, 
self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/internlm/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/internlm/internlm_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's InternLM parameters map from other formats, for example HuggingFace\nPyTorch and HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .internlm_model import InternLMForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=InternLMForCausalLM,\n    qkv_target_name=\"wqkv_pack\",\n    add_qkv_bias=True,\n    qkv_bias_optional=True,\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n"
  },
  {
    "path": "python/mlc_llm/model/internlm/internlm_model.py",
    "content": "\"\"\"\nImplementation for InternLM architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass InternLMConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the InternLM model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    rms_norm_eps: float\n    intermediate_size: int\n    bias: bool\n    use_cache: bool\n    pad_token_id: int\n    bos_token_id: int\n    eos_token_id: int\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    head_dim: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass InternLMAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: InternLMConfig):\n        self.hidden_size = config.hidden_size\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards\n   
     self.head_dim = config.head_dim\n        self.max_position_embeddings = config.context_window_size\n\n        self.wqkv_pack = nn.Linear(\n            self.hidden_size, 3 * self.num_heads * self.head_dim, bias=config.bias\n        )\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.wqkv_pack(hidden_states)\n        qkv = op.reshape(qkv, (b, s, 3 * h, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nclass InternLMMLP(nn.Module):\n    def __init__(self, config: InternLMConfig):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass InternLMDecoderLayer(nn.Module):\n    def __init__(self, config: InternLMConfig):\n        self.self_attn = InternLMAttention(config)\n        self.mlp = InternLMMLP(config)\n        
self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_heads * hd\n            k = self.self_attn.num_heads * hd\n            v = self.self_attn.num_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.wqkv_pack.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.bias:\n                _set(\n                    self.self_attn.wqkv_pack.bias,\n                    tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n            _set(\n                self.self_attn.o_proj.weight,\n                tp.ShardSingleDim(\"_shard_o_weight\", dim=1),\n            )\n            if config.bias:\n                _set(\n                    self.self_attn.o_proj.bias,\n                    tp.ShardSingleDim(\"_shard_o_bias\", dim=0),\n                )\n            _set(\n                self.mlp.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_gate_up\", segs=[i, i], dim=0),\n            )\n            _set(\n                self.mlp.down_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_down_proj\", dim=1),\n            )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n  
      hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass InternLMModel(nn.Module):\n    def __init__(self, config: InternLMConfig):\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [InternLMDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass InternLMForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: InternLMConfig):\n        self.model = InternLMModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = config.head_dim\n        self.vocab_size = config.vocab_size\n        self.rope_theta = 10000\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] 
= None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:-1,:]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: 
Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": 
\"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                  
  \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/internlm2/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/internlm2/internlm2_loader.py",
    "content": "# pylint: disable=W0611\n\"\"\"\nThis file specifies how MLC's InternLM2 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .internlm2_model import InternLM2Config, InternLM2ForCausalLM\n\n\ndef huggingface(model_config: InternLM2Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : InternLM2Config\n        The configuration of the InternLM2 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = InternLM2ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    def _convert_wqkv_layout(wqkv, dtype):\n        # HF stores wqkv grouped per KV head as [q x kv_groups, k, v]; regroup it\n        # into the contiguous [all q, all k, all v] layout that MLC expects.\n        config = model_config\n        kv_groups = config.num_attention_heads // config.num_key_value_heads\n        head_dim = config.hidden_size // config.num_attention_heads\n        wqkv = wqkv.reshape(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])\n        wq, wk, wv = np.split(wqkv, [kv_groups, kv_groups + 1], axis=1)  # pylint: disable=W0632\n        wq = wq.reshape(-1, wq.shape[-1])\n        wk = wk.reshape(-1, wk.shape[-1])\n        wv = wv.reshape(-1, wv.shape[-1])\n        return np.concatenate([wq, wk, wv], axis=0).astype(dtype)\n\n    for i in range(model_config.num_hidden_layers):\n        # Concatenate the gate (w1) and up (w3) projections in the MLP\n        mlp = f\"model.layers.{i}.feed_forward\"\n        mlc_name = f\"{mlp}.gate_up_proj.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"{mlp}.w1.weight\",\n                f\"{mlp}.w3.weight\",\n            ],\n            functools.partial(\n                lambda w1, w3, dtype: np.concatenate([w1, w3], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        mlc_name = f\"model.layers.{i}.attention.wqkv.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [mlc_name],\n            functools.partial(\n                _convert_wqkv_layout,\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/internlm2/internlm2_model.py",
    "content": "\"\"\"\nImplementation for InternLM2 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass InternLM2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the InternLM2 model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    rms_norm_eps: float\n    intermediate_size: int\n    bias: bool\n    use_cache: bool\n    rope_theta: int\n    pad_token_id: int\n    bos_token_id: int\n    eos_token_id: int\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    head_dim: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 2048),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 2048)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 2048),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 2048)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass InternLM2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: InternLM2Config):\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.hidden_size = config.hidden_size\n        self.rope_theta = config.rope_theta\n        self.num_heads = 
config.num_attention_heads // config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.max_position_embeddings = config.context_window_size\n\n        self.wqkv = nn.Linear(\n            self.hidden_size,\n            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=config.bias,\n        )\n        self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.wqkv(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.wo(output)\n        return attn_output\n\n\nclass InternLM2MLP(nn.Module):\n    def __init__(self, config: InternLM2Config):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.w2 = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n     
   return self.w2(op.silu(x1) * x2)\n\n\nclass InternLM2DecoderLayer(nn.Module):\n    def __init__(self, config: InternLM2Config):\n        self.attention = InternLM2Attention(config)\n        self.feed_forward = InternLM2MLP(config)\n        self.attention_norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.ffn_norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.attention.num_heads * hd\n            k = self.attention.num_key_value_heads * hd\n            v = self.attention.num_key_value_heads * hd\n            i = self.feed_forward.intermediate_size\n            _set(\n                self.attention.wqkv.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.bias:\n                _set(\n                    self.attention.wqkv.bias,\n                    tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n            _set(self.attention.wo.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.feed_forward.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.feed_forward.w2.weight, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        residual = hidden_states\n        hidden_states = self.attention_norm(hidden_states)\n        hidden_states = self.attention(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(hidden_states, residual=residual)\n        residual = hidden_states\n        hidden_states = 
self.ffn_norm(hidden_states)\n        hidden_states = self.feed_forward(hidden_states)\n        hidden_states = self._apply_residual(hidden_states, residual=residual)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass InternLM2Model(nn.Module):\n    def __init__(self, config: InternLM2Config):\n        self.padding_idx = config.pad_token_id\n        self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass InternLM2ForCausalLM(nn.Module):  # pylint: disable=R0902\n    def __init__(self, config: InternLM2Config):\n        self.model = InternLM2Model(config)\n        self.output = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.dtype = \"float32\"\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = 
dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.output(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.tok_embeddings(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:-1,:]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.output(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.output(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        
logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], 
self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                
\"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/llama/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/llama/llama_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization, make_awq_quant\n\nfrom .llama_model import LlamaConfig, LlamaForCausalLM\n\nawq_quant = make_awq_quant(LlamaForCausalLM)\n\n\nhuggingface = make_standard_hf_loader(\n    model_cls=LlamaForCausalLM,\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n\n\ndef awq(model_config: LlamaConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of AWQ parameters.\n    Parameters\n    ----------\n    model_config : LlamaConfig\n        The configuration of the Llama model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to AWQ.\n    \"\"\"\n    model, _ = awq_quant(model_config, quantization)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),  # type: ignore[attr-defined]\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{attn}.qkv_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{quantize_suffix}\",\n                    
f\"{attn}.k_proj.{quantize_suffix}\",\n                    f\"{attn}.v_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate(\n                        [q, k, v],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # Concat gate and up in MLP\n        mlp = f\"model.layers.{i}.mlp\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{mlp}.gate_up_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.{quantize_suffix}\",\n                    f\"{mlp}.up_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate(\n                        [gate, up],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/llama/llama_model.py",
    "content": "\"\"\"\nImplementation for Llama2 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass LlamaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Llama model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    rms_norm_eps: float\n    vocab_size: int\n    tie_word_embeddings: bool = False\n    position_embedding_base: int = 0\n    rope_scaling: Optional[Dict[str, Any]] = None\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    num_key_value_heads: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    pipeline_parallel_stages: int = 1\n    max_batch_size: int = 1\n    disaggregation: bool = False\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.rope_scaling is not None:\n            if \"rope_type\" not in self.rope_scaling:\n                self.rope_scaling = None\n            else:\n                assert (\n                    self.rope_scaling[\"rope_type\"] == \"llama3\"\n                ), f\"Unsupported RoPE scaling type {self.rope_scaling['rope_type']} for Llama\"\n\n        if 
self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if (\n            self.pipeline_parallel_stages <= 0\n            or self.pipeline_parallel_stages > self.num_hidden_layers\n        ):\n            raise ValueError(\n                f'Invalid \"pipeline_parallel_stages\" value ({self.pipeline_parallel_stages}). '\n                f'It must be between 1 and num_hidden_layers ({self.num_hidden_layers}).
'\n            )\n        if self.num_key_value_heads == 0:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass LlamaFFN(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass LlamaEmbedding(nn.Embedding):\n 
   \"\"\"The embedding module that can be shared with the final lm_head. From Qwen2Embedding.\"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass LlamaAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: LlamaConfig):\n        self.head_dim = config.head_dim\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert (\n            config.num_key_value_heads % config.tensor_parallel_shards == 0\n        ), f\"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards\"\n        assert (\n            config.num_key_value_heads >= config.tensor_parallel_shards\n        ), f\"Too large tensor_parallel_shards, must be smaller than {config.num_key_value_heads}\"\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * 
d),\n        )\n        return self.o_proj(output)\n\n\nclass LlamaDecoderLayer(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = LlamaAttention(config)\n        self.mlp = LlamaFFN(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + 
residual\n\n\nclass LlamaModel(nn.Module):\n    def __init__(self, config: LlamaConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = LlamaEmbedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.num_layers_per_stage = (\n            config.num_hidden_layers + config.pipeline_parallel_stages - 1\n        ) // config.pipeline_parallel_stages\n\n        # Compute pipeline layer partition.\n        layers_per_stage = (\n            config.num_hidden_layers + config.pipeline_parallel_stages - 1\n        ) // config.pipeline_parallel_stages\n        self.layer_partition = [\n            i * layers_per_stage for i in range(config.pipeline_parallel_stages)\n        ] + [config.num_hidden_layers]\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.layers):\n            if layer_id != 0 and layer_id in self.layer_partition:\n                hidden_states = op_ext.pipeline_stage_boundary(hidden_states)\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass LlamaForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: LlamaConfig):\n        self.model = LlamaModel(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, \"vocab_size\", bias=False)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = 
config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_scaling = config.rope_scaling\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.disaggregation = config.disaggregation\n        self.dtype = \"float32\"\n\n        def _set_pp():\n            # hidden layers\n            for layer_id in range(config.num_hidden_layers):\n                # Use the same ceil-divided layers-per-stage as LlamaModel.layer_partition.\n                stage = layer_id // self.model.num_layers_per_stage\n                for _, param in self.model.layers[layer_id].named_parameters():\n                    param.attrs[\"pipeline_stages\"] = [stage]\n            # last stage\n            last_stage = config.pipeline_parallel_stages - 1\n            self.model.norm.weight.attrs[\"pipeline_stages\"] = [last_stage]\n            # embedding table and lm_head are required by all stages\n            all_stages = list(range(config.pipeline_parallel_stages))\n            self.model.embed_tokens.weight.attrs[\"pipeline_stages\"] = all_stages\n            if not config.tie_word_embeddings:\n                self.lm_head.weight.attrs[\"pipeline_stages\"] = all_stages\n\n        _set_pp()\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            if self.tensor_parallel_shards > 1:\n                logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return self.get_logits(hidden_states)
\n\n    def batch_forward_to_last_hidden_states(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        return hidden_states\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def get_logits(self, hidden_states: Tensor):\n        op_ext.configure()\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor):\n        op_ext.configure()\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        hidden_states = op.take(hidden_states, logit_positions, axis=0)\n        return hidden_states\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. the last position\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, 
paged_kv_cache\n\n    def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_prefill_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_decode_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_verify_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: 
disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_scaling=self.rope_scaling,\n            layer_partition=self.model.layer_partition,\n            enable_disaggregation=self.disaggregation,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"get_logits\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_select_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": 
nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": 
{\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    
\"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/llama4/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/llama4/llama4_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Llama parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .llama4_model import Llama4Config, Llama4ForCausalLM\n\n\ndef huggingface(model_config: Llama4Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Llama4Config\n        The configuration of the Llama model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Llama4ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.text_config.num_hidden_layers):\n        # Add shared expert weights\n        mlp = f\"model.layers.{i}.feed_forward.shared_expert\"\n        mlc_name = f\"{mlp}.gate_up_proj.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"language_model.{mlp}.gate_proj.weight\",\n                f\"language_model.{mlp}.up_proj.weight\",\n            ],\n            functools.partial(\n                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        # Add router weights\n        mlp = 
f\"model.layers.{i}.feed_forward\"\n        mlc_name = f\"{mlp}.router.router.weight\"\n        hf_name = f\"language_model.{mlp}.router.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                hf_name,\n            ],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        # Add experts weights\n        mlp = f\"model.layers.{i}.feed_forward\"\n        hf_name = f\"language_model.{mlp}.experts.gate_up_proj\"\n        mlc_name = f\"{mlp}.experts.gate_up_proj\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                hf_name,\n            ],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        mlp = f\"model.layers.{i}.feed_forward\"\n        mlc_name = f\"{mlp}.experts.down_proj\"\n        hf_name = f\"language_model.{mlp}.experts.down_proj\"\n\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                hf_name,\n            ],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [f\"language_model.{mlc_name}\"],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/llama4/llama4_model.py",
    "content": "\"\"\"\nImplementation for Llama4 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nimport tvm\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\nfrom tvm.relax.frontend.nn.llm import position_embedding\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.qwen3.qwen3_model import ACT2FN\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Llama4TextConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Text portion of the Llama model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    rms_norm_eps: float\n    rope_theta: float\n    use_qk_norm: bool\n    interleave_moe_layer_step: int\n    num_experts_per_tok: int\n    num_local_experts: int\n    hidden_act: str\n    tie_word_embeddings: bool = False\n    position_embedding_base: int = 0\n    rope_scaling: Optional[Dict[str, Any]] = None\n    num_key_value_heads: int = 0\n    head_dim: int = 0\n    attn_scale: float = 0.1\n    floor_scale: int = 8192\n    vocab_size: int = 202048\n    attention_bias: bool = False\n    attn_temperature_tuning: bool = True\n    no_rope_layers: list[int] = None\n    no_rope_layer_interval: int = 4\n    moe_layers: list[int] = None\n\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if 
self.rope_scaling is not None:\n            if \"rope_type\" not in self.rope_scaling:\n                self.rope_scaling = None\n            else:\n                assert (\n                    self.rope_scaling[\"rope_type\"] == \"llama3\"\n                ), f\"Unsupported RoPE scaling type {self.rope_scaling['rope_type']} for Llama\"\n\n        # Define which layers to avoid RoPE\n        if self.no_rope_layers == []:\n            self.no_rope_layers = None\n\n        default_no_rope_layers = [\n            int((layer_idx + 1) % self.no_rope_layer_interval != 0)\n            for layer_idx in range(self.num_hidden_layers)\n        ]\n\n        self.no_rope_layers = self.no_rope_layers if self.no_rope_layers else default_no_rope_layers\n\n        # Define which layers to apply MoE\n        self.moe_layers = (\n            self.moe_layers\n            if self.moe_layers is not None\n            else list(\n                range(\n                    self.interleave_moe_layer_step - 1,\n                    self.num_hidden_layers,\n                    self.interleave_moe_layer_step,\n                )\n            )\n        )\n\n\n@dataclasses.dataclass\nclass Llama4Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Llama model.\"\"\"\n\n    text_config: Llama4TextConfig\n    tensor_parallel_shards: int = 1\n    context_window_size: int = 0\n    pipeline_parallel_stages: int = 1\n    prefill_chunk_size: int = 0\n    max_batch_size: int = 1\n    disaggregation: bool = False\n    max_position_embeddings = 4096 * 32\n    vocab_size: int = 202048\n\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self) -> None:\n        text_config_dict: Dict[str, Any]\n        if isinstance(self.text_config, ConfigBase):\n            text_config_dict = dataclasses.asdict(self.text_config)\n        else:\n            text_config_dict = dict(self.text_config)\n\n        for k, v in 
text_config_dict.pop(\"kwargs\", {}).items():\n            text_config_dict[k] = v\n\n        self.text_config = Llama4TextConfig.from_dict(text_config_dict)  # type: ignore\n\n        if self.context_window_size == 0:\n            # Fall back to max_position_embeddings\n\n            self.context_window_size = self.max_position_embeddings\n            logger.info(\n                \"%s not found in config.json. Falling back to %s (%d)\",\n                bold(\"context_window_size\"),\n                bold(\"max_position_embeddings\"),\n                self.context_window_size,\n            )\n\n        if self.text_config.num_key_value_heads == 0:\n            self.text_config.num_key_value_heads = self.text_config.num_attention_heads\n        if self.text_config.head_dim == 0:\n            self.text_config.head_dim = (\n                self.text_config.hidden_size // self.text_config.num_attention_heads\n            )\n        assert self.text_config.num_attention_heads % self.text_config.num_key_value_heads == 0\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass Llama4TextMLP(nn.Module):\n    def __init__(self, config: Llama4Config):\n        super().__init__()\n        if config.text_config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n         
       f\"Cannot split MLP intermediate size {config.text_config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = (\n            config.text_config.intermediate_size // config.tensor_parallel_shards\n        )\n        self.gate_up_proj = nn.Linear(\n            in_features=config.text_config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(\n            self.intermediate_size, config.text_config.hidden_size, bias=False\n        )\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        inter_out = op.silu(x1) * x2\n\n        return self.down_proj(inter_out)\n\n\nclass LlamaEmbedding(nn.Embedding):\n    \"\"\"The embedding module that can be shared with the final lm_head. From Qwen2Embedding.\"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass Llama4TextL2Norm(nn.Module):\n    def __init__(self, eps, hidden_size):\n        self.eps = eps\n        self.hidden_size = hidden_size\n\n    def forward(self, x):\n        weight = op.ones((self.hidden_size,), dtype=x.dtype)\n        return op.rms_norm(x, weight=weight, axes=[-1], epsilon=self.eps)\n\n\nclass Llama4TextAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Llama4Config, layer_idx):\n        self.head_dim = config.text_config.head_dim\n        self.attn_scale = config.text_config.attn_scale\n        self.floor_scale = config.text_config.floor_scale\n        self.num_attention_heads = config.text_config.num_attention_heads\n        
self.num_kv_heads = config.text_config.num_key_value_heads\n        self.num_q_heads = config.text_config.num_attention_heads // config.tensor_parallel_shards\n        assert config.text_config.num_key_value_heads % config.tensor_parallel_shards == 0, (\n            f\"num_kv_heads({config.text_config.num_key_value_heads}) must be divisible by \"\n            f\"tensor_parallel_shards\"\n        )\n\n        assert config.text_config.num_key_value_heads >= config.tensor_parallel_shards, (\n            f\"Too large tensor_parallel_shards, must be smaller than \"\n            f\"{config.text_config.num_key_value_heads}\"\n        )\n        self.num_kv_heads = config.text_config.num_key_value_heads // config.tensor_parallel_shards\n        self.q_proj = nn.Linear(\n            config.text_config.hidden_size,\n            self.num_q_heads * self.head_dim,\n            bias=config.text_config.attention_bias,\n        )\n        self.k_proj = nn.Linear(\n            config.text_config.hidden_size,\n            self.num_kv_heads * self.head_dim,\n            bias=config.text_config.attention_bias,\n        )\n        self.v_proj = nn.Linear(\n            config.text_config.hidden_size,\n            self.num_kv_heads * self.head_dim,\n            bias=config.text_config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.num_q_heads * self.head_dim,\n            config.text_config.hidden_size,\n            bias=config.text_config.attention_bias,\n        )\n\n        self.attn_temperature_tuning = config.text_config.attn_temperature_tuning\n        self.use_rope = config.text_config.no_rope_layers[layer_idx]\n\n        self.layer_idx = layer_idx\n\n        self.rope_theta = config.text_config.rope_theta\n        self.rope_scaling = config.text_config.rope_scaling\n        self.rope_scaling[\"rope_type\"] = \"llama4\"\n\n        self.use_qk_norm = config.text_config.use_qk_norm\n        self.rms_norm_eps = config.text_config.rms_norm_eps\n\n      
  self.q_norm = Llama4TextL2Norm(self.rms_norm_eps, self.head_dim)\n        self.k_norm = Llama4TextL2Norm(self.rms_norm_eps, self.head_dim)\n\n    def forward(  # pylint: disable=too-many-locals\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n        cache_position,\n    ):\n        d, h_q = self.head_dim, self.num_q_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        query_states = op.reshape(self.q_proj(hidden_states), (b, s, -1, d))\n        key_states = op.reshape(self.k_proj(hidden_states), (b, s, -1, d))\n        value_states = op.reshape(self.v_proj(hidden_states), (b, s, -1, d))\n\n        if self.use_rope:\n            qkv = op.concat([query_states, key_states, value_states], dim=2)\n\n            apply_rope = tvm.tir.IntImm(\"int64\", 1)\n\n            rotary_emb = position_embedding.llama4_rope_with_position_map(\n                theta=self.rope_theta,\n                scale=1.0,\n                head_dim=self.head_dim,\n                num_q_heads=self.num_q_heads,\n                num_kv_heads=self.num_kv_heads,\n                dtype=query_states.dtype,\n                rope_scaling=self.rope_scaling,\n            )\n\n            query_states, key_states, value_states = op.tensor_ir_op(\n                rotary_emb,\n                \"llama4_rope_with_position_map\",\n                args=[op.squeeze(qkv, axis=0), cache_position, apply_rope],\n                out=(\n                    Tensor.placeholder((s, h_q, d), query_states.dtype),\n                    Tensor.placeholder((s, self.num_kv_heads, d), query_states.dtype),\n                    Tensor.placeholder((s, self.num_kv_heads, d), query_states.dtype),\n                ),\n            )\n            query_states = query_states.reshape(b, s, h_q, d)\n            key_states = key_states.reshape(b, s, self.num_kv_heads, d)\n            value_states = value_states.reshape(b, s, self.num_kv_heads, d)\n\n     
   if self.use_qk_norm and self.use_rope:\n            query_states = self.q_norm(query_states)\n            key_states = self.k_norm(key_states)\n\n        if self.attn_temperature_tuning and not self.use_rope:\n            attn_scales = (\n                op.log(\n                    op.floor(\n                        (op.astype(cache_position, query_states.dtype) + 1.0) / self.floor_scale\n                    )\n                    + 1.0\n                )\n                * self.attn_scale\n                + 1.0\n            )\n\n            attn_scales = op.broadcast_to(attn_scales.reshape(1, s, 1, 1), (b, s, 1, 1))\n            query_states = query_states * attn_scales\n\n        qkv = op.concat([query_states, key_states, value_states], dim=2)\n\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass Llama4TextExperts(nn.Module):\n    def __init__(self, config: Llama4Config):\n        self.num_experts = config.text_config.num_local_experts\n        self.intermediate_size = (\n            config.text_config.intermediate_size // config.tensor_parallel_shards\n        )\n        self.hidden_size = config.text_config.hidden_size\n        self.expert_dim = self.intermediate_size\n\n        self.gate_up_proj = nn.Parameter(\n            shape=(self.num_experts, self.hidden_size, 2 * self.expert_dim)\n        )\n        self.down_proj = nn.Parameter(shape=(self.num_experts, self.expert_dim, self.hidden_size))\n        self.act_fn = ACT2FN[config.text_config.hidden_act]\n\n    def forward(self, hidden_states):\n        hidden_states = hidden_states.reshape(self.gate_up_proj.shape[0], -1, self.hidden_size)\n        gate_up = op.matmul(hidden_states, self.gate_up_proj)\n        gate, up = op.chunk(gate_up, chunks=2, dim=-1)\n        
next_states = op.matmul((up * self.act_fn(gate)), self.down_proj)\n        next_states = next_states.reshape(-1, self.hidden_size)\n        return next_states\n\n\nclass Llama4Router(nn.Module):\n    def __init__(self, config: Llama4Config):\n        self.num_experts = config.text_config.num_local_experts\n        self.top_k = config.text_config.num_experts_per_tok\n        self.intermediate_size = self.num_experts // config.tensor_parallel_shards\n        self.router = nn.Linear(\n            in_features=config.text_config.hidden_size,\n            out_features=self.intermediate_size,\n            bias=False,\n        )\n\n    def forward(self, hidden_states):\n        router_logits = self.router(hidden_states)\n        router_top_value, router_indices = op_ext.moe_misc.gating_topk(router_logits, self.top_k)\n\n        j_axis = op.arange(0, self.num_experts)\n        j_axis = op.unsqueeze(j_axis, 0)\n        idx_exp = op.unsqueeze(router_indices, -1)\n        mask = op.equal(idx_exp, j_axis)\n        val_exp = op.unsqueeze(router_top_value, -1)\n        neg_inf = op.full(mask.shape, -1e9, dtype=hidden_states.dtype)\n        masked_vals = op.where(mask, val_exp, neg_inf)\n        router_scores = op.max(masked_vals, axis=1)\n\n        router_scores = op.sigmoid(router_scores)\n        return router_scores, router_logits\n\n\nclass Llama4TextMoe(nn.Module):\n    def __init__(self, config: Llama4Config):\n        self.top_k = config.text_config.num_experts_per_tok\n        self.hidden_dim = config.text_config.hidden_size\n        self.num_experts = config.text_config.num_local_experts\n        self.experts = Llama4TextExperts(config)\n        self.router = Llama4Router(config)\n        self.shared_expert = Llama4TextMLP(config)\n\n    def forward(self, hidden_states):\n        hidden_states = hidden_states.reshape(-1, self.hidden_dim)\n        router_scores, _ = self.router(hidden_states)\n\n        routed_in = op.broadcast_to(\n            hidden_states.reshape(1, 
*hidden_states.shape),\n            [router_scores.shape[1], *hidden_states.shape],\n        )\n        routed_in = routed_in.reshape(-1, self.hidden_dim)\n\n        routed_in = routed_in * op.permute_dims(router_scores, axes=[1, 0]).reshape(-1, 1)\n\n        routed_out = self.experts(routed_in)\n        out = self.shared_expert(hidden_states)\n\n        out += op.sum(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]), axis=0)\n\n        return out\n\n\nclass Llama4TextDecoderLayer(nn.Module):\n    def __init__(self, config: Llama4Config, layer_idx):\n        rms_norm_eps = config.text_config.rms_norm_eps\n        self.self_attn = Llama4TextAttention(config, layer_idx)\n        self.is_moe_layer = layer_idx in config.text_config.moe_layers\n        if self.is_moe_layer:  # the 128E model interleaves dense / sparse\n            self.feed_forward = Llama4TextMoe(config)\n        else:\n            self.feed_forward = Llama4TextMLP(config)\n\n        self.input_layernorm = nn.RMSNorm(\n            config.text_config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.text_config.hidden_size, -1, rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                if hasattr(layer, \"weight\"):\n                    layer.weight.attrs[\"shard_strategy\"] = hint\n                else:\n                    layer.attrs[\"shard_strategy\"] = hint\n\n            _set(self.self_attn.q_proj, tp.ShardSingleDim(\"_shard_q\", dim=0))\n            _set(self.self_attn.k_proj, tp.ShardSingleDim(\"_shard_k\", dim=0))\n            _set(self.self_attn.v_proj, tp.ShardSingleDim(\"_shard_v\", dim=0))\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n\n            if isinstance(self.feed_forward, Llama4TextMLP):\n                i = self.feed_forward.intermediate_size\n                _set(\n                    
self.feed_forward.gate_up_proj,\n                    tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n                )\n                _set(\n                    self.feed_forward.down_proj,\n                    tp.ShardSingleDim(\"_shard_mlp_down\", dim=1),\n                )\n            else:\n                assert isinstance(self.feed_forward, Llama4TextMoe)\n                i = self.feed_forward.shared_expert.intermediate_size\n                _set(\n                    self.feed_forward.shared_expert.gate_up_proj,\n                    tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n                )\n                _set(\n                    self.feed_forward.shared_expert.down_proj,\n                    tp.ShardSingleDim(\"_shard_mlp_down\", dim=1),\n                )\n\n                j = self.feed_forward.experts.intermediate_size\n                _set(\n                    self.feed_forward.experts.gate_up_proj,\n                    tp.ShardSingleDim(\"_shard_expert_mlp_up\", segs=[j, j], dim=2),\n                )\n                _set(\n                    self.feed_forward.experts.down_proj,\n                    tp.ShardSingleDim(\"_shard_expert_mlp_down\", dim=1),\n                )\n\n                _set(\n                    self.feed_forward.router.router,\n                    tp.ShardSingleDim(\"_shard_router\", dim=0),\n                )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n        cache_position,\n    ):\n        out = self.self_attn(\n            self.input_layernorm(hidden_states),\n            paged_kv_cache,\n            layer_id,\n            cache_position,\n        )\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.feed_forward(self.post_attention_layernorm(hidden_states))\n\n        
hidden_states = self._apply_residual(\n            op.reshape(out, hidden_states.shape), residual=hidden_states\n        )\n\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Llama4TextModel(nn.Module):\n    def __init__(self, config: Llama4Config):\n        assert config.text_config.hidden_size % config.text_config.num_attention_heads == 0\n        self.embed_tokens = LlamaEmbedding(\"vocab_size\", config.text_config.hidden_size)\n        self.layers = nn.ModuleList(\n            [\n                Llama4TextDecoderLayer(config, layer_idx)\n                for layer_idx in range(config.text_config.num_hidden_layers)\n            ]\n        )\n        self.norm = nn.RMSNorm(\n            config.text_config.hidden_size,\n            -1,\n            config.text_config.rms_norm_eps,\n            bias=False,\n        )\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        cache_position = paged_kv_cache.get_query_positions(\n            input_embed.shape[0] * input_embed.shape[1]\n        )\n\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id, cache_position)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Llama4ForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Llama4Config):\n        self.text_config = config.text_config\n        self.model = Llama4TextModel(config)\n        self.tie_word_embeddings = self.text_config.tie_word_embeddings\n        if not self.text_config.tie_word_embeddings:\n            self.lm_head = nn.Linear(self.text_config.hidden_size, \"vocab_size\", bias=False)\n        self.num_hidden_layers = 
self.text_config.num_hidden_layers\n        self.num_attention_heads = self.text_config.num_attention_heads\n        self.num_key_value_heads = self.text_config.num_key_value_heads\n        self.head_dim = self.text_config.head_dim\n        self.hidden_size = self.text_config.hidden_size\n        self.vocab_size = self.text_config.vocab_size\n        self.rope_scaling = self.text_config.rope_scaling\n        self.rope_theta = self.text_config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.disaggregation = config.disaggregation\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            if self.tensor_parallel_shards > 1:\n                logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return self.get_logits(hidden_states)\n\n    def batch_forward_to_last_hidden_states(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        return hidden_states\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def get_logits(self, hidden_states: Tensor):\n        op_ext.configure()\n        if self.tie_word_embeddings:\n            logits = 
self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor):\n        op_ext.configure()\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        hidden_states = op.take(hidden_states, logit_positions, axis=0)\n        return hidden_states\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        
logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_prefill_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_decode_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_verify_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            
num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NONE,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_scaling=self.rope_scaling,\n            enable_disaggregation=self.disaggregation,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"get_logits\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_select_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                
\"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": 
nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/llava/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/llava/llava_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Llava parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization, make_awq_quant\n\nfrom .llava_model import LlavaConfig, LlavaForCausalLM\n\nawq_quant = make_awq_quant(LlavaForCausalLM)\n\n\ndef _num_layers(config: object) -> int:\n    return config.text_config.num_hidden_layers  # type: ignore[attr-defined]\n\n\nhuggingface = make_standard_hf_loader(\n    model_cls=LlavaForCausalLM,\n    layer_prefix=\"language_model.model.layers\",\n    add_unused=[\"rotary_emb.inv_freq\"],\n    num_layers_getter=_num_layers,\n)\n\n\ndef awq(model_config: LlavaConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of AWQ parameters.\n    Parameters\n    ----------\n    model_config : LlavaConfig\n        The configuration of the Llava model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to AWQ.\n    \"\"\"\n    model, _ = awq_quant(model_config, quantization)\n    _, _named_params = model.export_tvm(spec=model.get_default_spec())\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.text_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"language_model.model.layers.{i}.self_attn\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{attn}.qkv_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                
mlc_name,\n                [\n                    f\"{attn}.q_proj.{quantize_suffix}\",\n                    f\"{attn}.k_proj.{quantize_suffix}\",\n                    f\"{attn}.v_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # Concat gate and up in MLP\n        mlp = f\"language_model.model.layers.{i}.mlp\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{mlp}.gate_up_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.{quantize_suffix}\",\n                    f\"{mlp}.up_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/llava/llava_model.py",
    "content": "\"\"\"\nImplementation of LLaVa Model\nImplements the CLIP Vision Encoder. Uses Llama for the Language Encoder.\n\"\"\"\n\nimport dataclasses\nimport logging\nfrom typing import Any, Dict, Optional\n\nfrom tvm import tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Module, Tensor\nfrom tvm.relax.frontend.nn.op import permute_dims, reshape, wrap_nested\nfrom tvm.relax.op import strided_slice\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.model_preset import MODEL_PRESETS\nfrom mlc_llm.model.vision import CLIPVisionConfig, CLIPVisionModel, ImageProcessor\nfrom mlc_llm.nn import PagedKVCache, RopeMode\n\nfrom ...support.config import ConfigBase\nfrom ..llama.llama_model import LlamaConfig, LlamaForCausalLM\nfrom ..mistral.mistral_model import MistralConfig, MistralForCausalLM\n\nlogger = logging.getLogger(__name__)\n\n\nCONFIG_MAP = {\"LlamaForCausalLM\": LlamaConfig, \"MistralForCausalLM\": MistralConfig}\nARCHITECTURE_MAP = {\n    \"LlamaForCausalLM\": LlamaForCausalLM,\n    \"MistralForCausalLM\": MistralForCausalLM,\n}\n\n\n@dataclasses.dataclass\nclass LlavaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"\n    LLaVa Config\n    \"\"\"\n\n    image_token_index: int\n    text_config: LlamaConfig\n    vision_config: CLIPVisionConfig\n    vocab_size: int\n    context_window_size: int = -1\n    sliding_window_size: int = -1\n    prefill_chunk_size: int = -1\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    text_architecture: str = \"LlamaForCausalLM\"\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self) -> None:\n        vision_config_dict: Dict[str, Any]\n        if isinstance(self.vision_config, CLIPVisionConfig):\n            vision_config_dict = dataclasses.asdict(self.vision_config)\n        else:\n            vision_config_dict = dict(self.vision_config)\n\n        for k, v in vision_config_dict.pop(\"kwargs\", 
{}).items():\n            vision_config_dict[k] = v\n\n        self.vision_config = CLIPVisionConfig.from_dict(vision_config_dict)\n\n        text_config_dict: Dict[str, Any]\n        if isinstance(self.text_config, ConfigBase):\n            text_config_dict = dataclasses.asdict(self.text_config)\n        else:\n            text_config_dict = dict(self.text_config)\n\n        if \"_name_or_path\" in text_config_dict:\n            hf_config = self.get_hf_config(text_config_dict)\n            text_config_dict.update(hf_config)\n            architectures = text_config_dict[\"architectures\"]\n            assert len(architectures) == 1\n            self.text_architecture = architectures[0]\n        else:\n            for k, v in text_config_dict.pop(\"kwargs\", {}).items():\n                text_config_dict[k] = v\n\n        self.text_config = CONFIG_MAP[self.text_architecture].from_dict(  # type: ignore\n            text_config_dict\n        )\n\n        for k in [\"context_window_size\", \"sliding_window_size\", \"prefill_chunk_size\"]:\n            if getattr(self, k) <= 0:\n                if hasattr(self.text_config, k):\n                    setattr(self, k, getattr(self.text_config, k))\n\n    def get_hf_config(self, text_config_dict: Dict[str, Any]) -> Dict[str, Any]:\n        \"\"\"\n        Get the Hugging Face config of the text model\n        \"\"\"\n\n        hf_config: Dict[str, Any]\n        try:\n            # pylint: disable=import-outside-toplevel, import-error\n            from transformers import AutoConfig\n\n            hf_config = AutoConfig.from_pretrained(text_config_dict[\"_name_or_path\"]).to_dict()\n        except (ImportError, OSError) as e:\n            # If transformers is not installed, get the config from preset\n            # Llama2 is gated so it throws an OSError. 
Get the config from preset instead\n            preset_mapping = {\n                \"meta-llama/Llama-2-7b-hf\": \"llama2_7b\",\n                \"meta-llama/Llama-2-13b-hf\": \"llama2_13b\",\n                \"lmsys/vicuna-7b-v1.5\": \"llama2_7b\",\n                \"mistralai/Mistral-7B-v0.1\": \"mistral_7b\",\n            }\n            if text_config_dict[\"_name_or_path\"] in preset_mapping:\n                hf_config = MODEL_PRESETS[preset_mapping[text_config_dict[\"_name_or_path\"]]]\n            else:\n                raise ValueError(\"Unsupported text model\") from e\n\n        return hf_config\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass LlavaMultiModalProjector(nn.Module):\n    def __init__(self, config: LlavaConfig):\n        super().__init__()\n\n        self.linear_1 = nn.Linear(\n            config.vision_config.hidden_size, config.text_config.hidden_size, bias=True\n        )\n        self.act = nn.GELU()\n        self.linear_2 = nn.Linear(\n            config.text_config.hidden_size, config.text_config.hidden_size, bias=True\n        )\n\n    def forward(self, image_features: Tensor) -> Tensor:\n        hidden_states = self.linear_1(image_features)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass LlavaForCausalLM(Module):\n    def __init__(self, config: LlavaConfig):\n        super().__init__()\n        self.config = config\n        self.vision_tower = CLIPVisionModel(config.vision_config)\n        self.image_processor = ImageProcessor()\n        self.multi_modal_projector = LlavaMultiModalProjector(config)\n        self.language_model = ARCHITECTURE_MAP[config.text_architecture](config.text_config)\n        self.vocab_size = config.vocab_size\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        self.language_model.to(dtype=dtype)\n        if dtype is not 
None:\n            self.dtype = dtype\n\n    def embed(self, input_ids: Tensor) -> Tensor:\n        return self.language_model.embed(input_ids)\n\n    def image_preprocess(self, pixel_values: Tensor) -> Tensor:\n        pixel_values = permute_dims(pixel_values, axes=(0, 3, 1, 2))  # NHWC -> NCHW\n        pixel_values = self.image_processor.resize(\n            pixel_values,\n            {\n                \"shortest_edge\": self.config.vision_config.image_size,\n            },\n        )\n        pixel_values = self.image_processor.crop(\n            pixel_values,\n            {\n                \"height\": self.config.vision_config.image_size,\n                \"width\": self.config.vision_config.image_size,\n            },\n        )\n        pixel_values = self.image_processor.rescale(pixel_values)\n        pixel_values = self.image_processor.normalize(pixel_values)\n        return pixel_values\n\n    def image_embed(self, pixel_values: Tensor) -> Tensor:\n        pixel_values = self.image_preprocess(pixel_values)\n        pixel_values = pixel_values.astype(self.dtype)\n        image_features_all = self.vision_tower.forward(pixel_values)\n        image_features = wrap_nested(\n            strided_slice(\n                image_features_all._expr,  # pylint: disable=protected-access\n                axes=[1],\n                begin=[1],\n                end=[image_features_all.shape[1]],\n            ),\n            name=\"slice\",\n        )\n        image_features = self.multi_modal_projector(image_features)\n        image_features = reshape(image_features, shape=(-1, self.config.text_config.hidden_size))\n        return image_features\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        return self.language_model.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n\n    def prefill(self, 
input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        return self.language_model.prefill(input_embed, paged_kv_cache)\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        return self.language_model.decode(input_embed, paged_kv_cache)\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        return self.language_model.batch_prefill(input_embeds, logit_positions, paged_kv_cache)\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        return self.language_model.batch_decode(input_embeds, paged_kv_cache)\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        return self.language_model.batch_verify(input_embeds, paged_kv_cache)\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.config.text_config.num_hidden_layers,\n            num_attention_heads=self.config.text_config.num_attention_heads\n            // self.config.tensor_parallel_shards,\n            num_key_value_heads=self.config.text_config.num_key_value_heads\n            // self.config.tensor_parallel_shards,\n            qk_head_dim=self.config.text_config.head_dim,\n            v_head_dim=self.config.text_config.head_dim,\n            
rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.language_model.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"image_embed\": {\n                \"pixel_values\": nn.spec.Tensor(\n                    [1, \"image_height\", \"image_width\", 3],\n                    \"uint8\",\n                ),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor(\n                    [1, \"seq_len\", self.config.text_config.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor(\n                    [1, 1, self.config.text_config.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor(\n                    [1, \"seq_len\", self.config.text_config.hidden_size], self.dtype\n                ),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": 
nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor(\n                    [\"batch_size\", 1, self.config.text_config.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor(\n                    [1, \"seq_len\", self.config.text_config.hidden_size], self.dtype\n                ),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/medusa/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/medusa/medusa_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Medusa parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .medusa_model import MedusaModel\n\nhuggingface = make_standard_hf_loader(\n    model_cls=MedusaModel,\n    include_qkv=False,\n    include_gate_up=False,\n)\n"
  },
  {
    "path": "python/mlc_llm/model/medusa/medusa_model.py",
    "content": "\"\"\"Medusa model definition.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.config import ConfigBase\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass MedusaConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Medusa model.\"\"\"\n\n    medusa_num_heads: int\n    medusa_num_layers: int\n    hidden_size: int\n    vocab_size: int\n    max_batch_size: int = 1\n    tensor_parallel_shards: int = 1\n\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    # Unused parameters. Kept for compatibility with the compilation flow.\n    prefill_chunk_size: int = -1\n    context_window_size: int = -1\n\n\n# pylint: disable=missing-docstring\n\n\nclass ResBlock(nn.Module):\n    \"\"\"Residual block with SiLU activation.\"\"\"\n\n    def __init__(self, hidden_size):\n        super().__init__()\n        self.linear = nn.Linear(hidden_size, hidden_size)\n        self.act = nn.SiLU()\n\n    def forward(self, x):\n        return x + self.act(self.linear(x))\n\n\nclass MedusaModel(nn.Module):\n    \"\"\"Medusa model definition.\"\"\"\n\n    def __init__(self, config: MedusaConfig):\n        self.hidden_size = config.hidden_size\n        self.dtype = \"float32\"\n        self.medusa_head = nn.ModuleList(\n            [\n                nn.ModuleList(\n                    [ResBlock(config.hidden_size) for _ in range(config.medusa_num_layers)]\n                    + [nn.Linear(config.hidden_size, config.vocab_size, bias=False)]\n                )\n                for _ in range(config.medusa_num_heads)\n            ]\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"get_logits\": {\n                \"hidden_states\": nn.spec.Tensor([\"batch_size\", self.hidden_size], self.dtype),\n                \"$\": {\n                    
\"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n\n    def get_logits(self, hidden_states: nn.Tensor):\n        logits = []\n        for head in self.medusa_head:\n            logits.append(head(hidden_states).astype(\"float32\"))\n        return logits\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n"
  },
  {
    "path": "python/mlc_llm/model/minicpm/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/minicpm/minicpm_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's MiniCPM parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .minicpm_model import MiniCPMConfig, MiniCPMForCausalLM\n\n\ndef huggingface(model_config: MiniCPMConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : MiniCPMConfig\n        The configuration of the MiniCPM model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = MiniCPMForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # map attention weight\n        attn = f\"model.layers.{i}.self_attn\"\n        for weight_type in [\"weight\"]:\n            mlc_name = f\"{attn}.wqkv_pack.{weight_type}\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{weight_type}\",\n                    f\"{attn}.k_proj.{weight_type}\",\n                    f\"{attn}.v_proj.{weight_type}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n     
           ),\n            )\n\n    if model_config.num_experts == 0:\n        for i in range(model_config.num_hidden_layers):\n            # map mlp weight\n            mlp = f\"model.layers.{i}.mlp\"\n            mlc_name = f\"{mlp}.gate_up_proj.weight\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.weight\",\n                    f\"{mlp}.up_proj.weight\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    else:\n        for i in range(model_config.num_hidden_layers):\n            # map mlp weight\n            mlp = f\"model.layers.{i}.mlp\"\n            mlc_mlp = f\"model.layers.{i}.mlp\"\n            mlc_name = f\"{mlc_mlp}.e1_e3.weight\"\n            mlc_param = named_parameters[mlc_name]\n\n            def combine_expert_gate_up(*hf_params, dtype):\n                stack = []\n                for i in range(0, len(hf_params), 2):\n                    stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))\n                return np.stack(stack, axis=0).astype(dtype)\n\n            mapping.add_mapping(\n                mlc_name,\n                functools.reduce(\n                    lambda a, b: a + b,\n                    [\n                        [\n                            f\"{mlp}.experts.{expert_id}.w1.weight\",\n                            f\"{mlp}.experts.{expert_id}.w3.weight\",\n                        ]\n                        for expert_id in range(model_config.num_experts)\n                    ],\n                ),\n                functools.partial(\n                    combine_expert_gate_up,\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n            mlc_name = 
f\"{mlc_mlp}.e2.weight\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.experts.{expert_id}.w2.weight\"\n                    for expert_id in range(model_config.num_experts)\n                ],\n                functools.partial(\n                    lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n            mlc_name = f\"{mlc_mlp}.gate.weight\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [f\"{mlp}.gate.weight\"],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        # Skip lm_head.weight if tie_word_embeddings is enabled\n        if mlc_name == \"lm_head.weight\" and model_config.tie_word_embeddings:\n            continue\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/minicpm/minicpm_model.py",
    "content": "\"\"\"\nImplementation for Minicpm architecture.\n\"\"\"\n\nimport dataclasses\nimport math\nfrom functools import partial\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.nn.expert import MixtralExperts\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass MiniCPMConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the MiniCPM model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    hidden_act: str\n    rms_norm_eps: float\n    intermediate_size: int\n    scale_emb: int\n    scale_depth: float\n    dim_model_base: int\n    use_cache: bool\n    bos_token_id: int\n    eos_token_id: int\n    tie_word_embeddings: bool = False\n    rope_theta: int = 10000\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    max_batch_size: int = 1\n    num_experts_per_tok: int = 0\n    num_experts: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass MiniCPMAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: MiniCPMConfig):\n        super().__init__()  # Make sure to call the parent class constructor\n        self.hidden_size = config.hidden_size\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} 
attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n\n        self.num_heads = config.num_attention_heads // self.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.context_window_size\n\n        self.wqkv_pack = nn.Linear(\n            in_features=self.hidden_size,\n            out_features=(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.wqkv_pack(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass MiniCPMEmbedding(nn.Embedding):\n    \"\"\"The embedding module specialized for MiniCPM so that\n    it can be shared with the final lm_head.\n    \"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return 
nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass MiniCPMMLP(nn.Module):\n    def __init__(self, config: MiniCPMConfig):\n        self.hidden_size = config.hidden_size\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n\n        self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass MiniCPMMoE(nn.Module):\n    def __init__(self, config: MiniCPMConfig):\n        self.num_local_experts = config.num_experts\n        self.num_experts_per_tok = config.num_experts_per_tok\n        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.e1_e3 = MixtralExperts(\n            self.num_local_experts,\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n        )\n        self.e2 = MixtralExperts(\n            self.num_local_experts,\n            in_features=self.intermediate_size,\n            out_features=config.hidden_size,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n        )\n        self.dtype = \"float32\"\n\n    def forward(self, x: Tensor):  # pylint: disable=too-many-locals\n        def _expert_forward(x: Tensor, indptr: 
Tensor):\n            x1_x3 = self.e1_e3(x, indptr)\n            x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1)\n            x = self.e2(op.silu(x1) * x3, indptr)\n            return x\n\n        experts_per_tok = self.num_experts_per_tok  # activated experts per token\n        local_experts = self.num_local_experts  # total number of experts\n        batch_size, seq_len, hidden_size = x.shape\n        num_tokens = batch_size * seq_len\n        x = x.reshape(num_tokens, hidden_size)\n        # gate: [num_tokens, local_experts]\n        gate: Tensor = self.gate(x)\n        # expert_weights: [num_tokens, experts_per_tok]\n        # expert_indices: [num_tokens, experts_per_tok]\n        expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(gate, experts_per_tok)\n        use_ft = (\n            op_ext.get_store().cutlass_group_gemm or op_ext.get_store().faster_transformer\n        ) and self.dtype == \"float16\"\n        if num_tokens == 1:\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            x = _expert_forward(x, expert_indices)\n        else:\n            # cumsum: [num_tokens * local_experts]\n            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, local_experts)\n            # indices: [num_tokens * experts_per_tok]\n            reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)\n            if use_ft:\n                # indptr: [num_local_experts]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum, local_experts, num_tokens, inclusive=True, out_dtype=\"int64\"\n                )\n            else:\n                # indptr: [num_local_experts + 1]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum,\n                    local_experts,\n                    num_tokens,\n                    inclusive=False,\n                    out_dtype=\"int32\",\n                )\n            # x: [num_tokens * experts_per_tok, 
hidden_size]\n            x = op.take(x, token_indices, axis=0)\n            x = _expert_forward(x, indptr)\n            x = op_ext.moe_misc.scatter_output(x, reverse_indices)\n        # x: [num_tokens, experts_per_tok, hidden_size]\n        x = x.reshape(  # pylint: disable=too-many-function-args\n            num_tokens, experts_per_tok, hidden_size\n        ) * expert_weights.reshape(  # pylint: disable=too-many-function-args\n            num_tokens, experts_per_tok, 1\n        )\n        # x: [num_tokens, hidden_size]\n        x = op_ext.moe_misc.moe_sum(x, dim=1)\n        x = x.reshape(batch_size, seq_len, hidden_size)  # pylint: disable=too-many-function-args\n        return x\n\n\nclass MiniCPMDecoderLayer(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: MiniCPMConfig):\n        self.scale_depth = config.scale_depth\n        self.hidden_size = config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.self_attn = MiniCPMAttention(config)\n        self.num_experts = config.num_experts\n        if self.num_experts == 0:\n            self.mlp = MiniCPMMLP(config)\n        else:\n            self.mlp = MiniCPMMoE(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.wqkv_pack.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            
_set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            if self.num_experts == 0:\n                _set(\n                    self.mlp.gate_up_proj.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n                )\n                _set(\n                    self.mlp.down_proj.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_down\", dim=1),\n                )\n            else:\n                _set(\n                    self.mlp.e1_e3.weight,\n                    tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=1),\n                )\n                _set(self.mlp.e2.weight, tp.ShardSingleDim(\"_shard_mlp_down\", dim=2))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        residual = hidden_states\n        hidden_states = self.input_layernorm(hidden_states)\n        hidden_states = self.self_attn(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(\n            hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)),\n            residual,\n        )\n        residual = hidden_states\n        hidden_states = self.post_attention_layernorm(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = self._apply_residual(\n            hidden_states * (self.scale_depth / math.sqrt(self.num_hidden_layers)),\n            residual,\n        )\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass MiniCPMModel(nn.Module):\n    def __init__(self, config: MiniCPMConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = MiniCPMEmbedding(config.vocab_size, 
config.hidden_size)\n        self.layers = nn.ModuleList(\n            [MiniCPMDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass MiniCPMForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: MiniCPMConfig):\n        self.model = MiniCPMModel(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.hidden_size // config.num_attention_heads\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.scale_emb = config.scale_emb\n        self.scale_width = self.hidden_size // config.dim_model_base\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, 
paged_kv_cache) / self.scale_width\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids) * self.scale_emb\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last token\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache) / self.scale_width\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache) / self.scale_width\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: 
Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    
\"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n         
       },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/ministral3/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/ministral3/ministral3_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Ministral3 parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\nfrom typing import Callable, List, Optional, Tuple\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping, QuantizeMapping\nfrom mlc_llm.quantization import BlockScaleQuantize, Quantization\n\nfrom .ministral3_model import Ministral3Config, Mistral3ForConditionalGeneration\n\n\ndef _dequantize_block_scale_weight(  # pylint: disable=too-many-locals\n    weight: np.ndarray, weight_scale: np.ndarray, block_size: Tuple[int, int]\n) -> np.ndarray:\n    \"\"\"Reconstruct float weights from FP8 block-scale storage.\"\"\"\n\n    rows, cols = weight.shape\n    block_rows, block_cols = block_size\n    out = np.empty((rows, cols), dtype=\"float32\")\n    weight = weight.astype(\"float32\")\n    num_row_blocks, num_col_blocks = weight_scale.shape\n    for i in range(num_row_blocks):\n        row_start = i * block_rows\n        if row_start >= rows:\n            break\n        row_end = min(row_start + block_rows, rows)\n        scale_row = weight_scale[i]\n        for j in range(num_col_blocks):\n            col_start = j * block_cols\n            if col_start >= cols:\n                break\n            col_end = min(col_start + block_cols, cols)\n            out[row_start:row_end, col_start:col_end] = (\n                weight[row_start:row_end, col_start:col_end] * scale_row[j]\n            )\n    return out\n\n\ndef huggingface(  # pylint: disable=too-many-locals,too-many-statements\n    model_config: Ministral3Config, quantization: Quantization\n) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Ministral3Config\n        The configuration of the Ministral3 model.\n\n    quantization : Quantization\n        The 
quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Mistral3ForConditionalGeneration(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    if isinstance(quantization, BlockScaleQuantize):\n        # Convert the model to block-scale quantized model before loading parameters\n        model = quantization.quantize_model(model, QuantizeMapping({}, {}), \"\")\n        if model_config.weight_block_size is None:\n            raise ValueError(\n                \"The input Ministral 3 model is not fp8 block quantized. \"\n                \"Thus BlockScaleQuantize is not supported.\"\n            )\n\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    raw_params = dict(_named_params)\n    if any(name.startswith(\"language_model.\") for name in raw_params):\n        named_parameters = {\n            name.replace(\"language_model.\", \"\", 1): value for name, value in raw_params.items()\n        }\n    else:\n        named_parameters = raw_params\n\n    mapping = ExternMapping()\n\n    hf_prefix = \"\"\n    if \"vision_config\" in model_config.kwargs:\n        hf_prefix = \"language_model.\"\n\n    def hf(name: str) -> str:\n        return f\"{hf_prefix}{name}\"\n\n    if (\n        not isinstance(quantization, BlockScaleQuantize)\n        and model_config.weight_block_size is not None\n    ):\n        raise ValueError(\n            \"The input Ministral 3 model is fp8 block quantized. 
\"\n            \"Please use BlockScaleQuantize for the model.\"\n        )\n\n    # Helper function to add both weight and scale mappings\n    def add_weight_and_scale_mapping(  # pylint: disable=too-many-locals\n        weight_mlc_name: str,\n        weight_hf_names: List[str],\n        weight_transform_func: Callable,\n        activation_transform_func: Optional[Callable] = None,\n    ):\n        mlc_param = named_parameters[weight_mlc_name]\n        mapping.add_mapping(\n            weight_mlc_name,\n            weight_hf_names,\n            functools.partial(weight_transform_func, dtype=mlc_param.dtype),\n        )\n\n        if isinstance(quantization, BlockScaleQuantize):\n            weight_scale_mlc_name = f\"{weight_mlc_name}_scale_inv\"\n            if weight_scale_mlc_name in named_parameters:\n                weight_scale_hf_names = [f\"{name}_scale_inv\" for name in weight_hf_names]\n                weight_scale_param = named_parameters[weight_scale_mlc_name]\n                expected_weight_scale_shape = tuple(int(dim) for dim in weight_scale_param.shape)\n\n                def _weight_scale_transform(*arrays, dtype: str, _transform=weight_transform_func):\n                    processed = []\n                    for arr in arrays:\n                        arr_np = np.asarray(arr)\n                        if arr_np.ndim == 0:\n                            arr_np = arr_np.reshape((1,))\n                        processed.append(arr_np)\n                    result = _transform(*processed, dtype=dtype)\n                    result = np.asarray(result, dtype=dtype)\n                    if result.shape == expected_weight_scale_shape:\n                        return result\n                    if result.shape == ():\n                        return np.full(expected_weight_scale_shape, result.item(), dtype=dtype)\n                    if result.shape == (1,) and expected_weight_scale_shape != (1,):\n                        return np.broadcast_to(result, 
expected_weight_scale_shape).astype(dtype)\n                    if (\n                        result.ndim == 1\n                        and result.size > 1\n                        and len(expected_weight_scale_shape) >= 2\n                        and expected_weight_scale_shape[0] % result.size == 0\n                    ):\n                        rows_per_segment = expected_weight_scale_shape[0] // result.size\n                        tiled = np.repeat(result, rows_per_segment)\n                        tiled = tiled.reshape(expected_weight_scale_shape[0], 1)\n                        return np.broadcast_to(tiled, expected_weight_scale_shape).astype(dtype)\n                    raise ValueError(\n                        f\"Unexpected weight scale shape {result.shape} for \"\n                        f\"{weight_scale_mlc_name}, expected {expected_weight_scale_shape}\"\n                    )\n\n                mapping.add_mapping(\n                    weight_scale_mlc_name,\n                    weight_scale_hf_names,\n                    functools.partial(_weight_scale_transform, dtype=weight_scale_param.dtype),\n                )\n            activation_scale_mlc_name = f\"{weight_mlc_name[: -len('.weight')]}.activation_scale\"\n            if activation_scale_mlc_name in named_parameters:\n                activation_scale_hf_names = [\n                    f\"{name[: -len('.weight')]}.activation_scale\" for name in weight_hf_names\n                ]\n                activation_scale_param = named_parameters[activation_scale_mlc_name]\n                transform = activation_transform_func or weight_transform_func\n                expected_shape = tuple(int(dim) for dim in activation_scale_param.shape)\n\n                def _activation_scale_transform(*arrays, dtype: str, _transform=transform):\n                    result = _transform(*arrays, dtype=dtype)\n                    result = np.asarray(result, dtype=dtype)\n                    if result.shape == 
expected_shape:\n                        return result\n                    if result.shape == ():\n                        # HF checkpoint stores a single scale; broadcast across the expected\n                        # dimension.\n                        return np.full(expected_shape, result.item(), dtype=dtype)\n                    if result.shape == (1,) and expected_shape != (1,):\n                        return np.broadcast_to(result, expected_shape).astype(dtype)\n                    if (\n                        result.ndim == 1\n                        and result.size > 1\n                        and len(expected_shape) >= 1\n                        and expected_shape[0] % result.size == 0\n                    ):\n                        rows_per_segment = expected_shape[0] // result.size\n                        tiled = np.repeat(result, rows_per_segment)\n                        return tiled.reshape(expected_shape).astype(dtype)\n                    raise ValueError(\n                        f\"Unexpected activation scale shape {result.shape} for \"\n                        f\"{activation_scale_mlc_name}, expected {expected_shape}\"\n                    )\n\n                mapping.add_mapping(\n                    activation_scale_mlc_name,\n                    activation_scale_hf_names,\n                    functools.partial(\n                        _activation_scale_transform, dtype=activation_scale_param.dtype\n                    ),\n                )\n\n    def identity_transform(param: np.ndarray, dtype: str):\n        return param.astype(dtype)\n\n    def make_shared_activation_transform(target_name: str):\n        def func(first: np.ndarray, *rest: np.ndarray, dtype: str):\n            for arr in rest:\n                if not np.allclose(arr, first):\n                    raise ValueError(\n                        f\"Activation scales for {target_name} must be identical between \"\n                        \"concatenated 
sources.\"\n                    )\n            return first.astype(dtype)\n\n        return func\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        mlc_name = f\"{attn}.qkv_proj.weight\"\n        proj_sources = [hf(f\"{attn}.{proj}.weight\") for proj in [\"q_proj\", \"k_proj\", \"v_proj\"]]\n        add_weight_and_scale_mapping(\n            mlc_name,\n            proj_sources,\n            lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n            activation_transform_func=make_shared_activation_transform(\n                f\"{mlc_name}_activation_scale\"\n            ),\n        )\n\n        # Add gates in MLP\n        mlp = f\"model.layers.{i}.mlp\"\n        mlc_name = f\"{mlp}.gate_up_proj.weight\"\n        gate_sources = [hf(f\"{mlp}.{proj}.weight\") for proj in [\"gate_proj\", \"up_proj\"]]\n        add_weight_and_scale_mapping(\n            mlc_name,\n            gate_sources,\n            lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n            activation_transform_func=make_shared_activation_transform(\n                f\"{mlc_name}_activation_scale\"\n            ),\n        )\n\n        for linear_name in [f\"{attn}.o_proj.weight\", f\"{mlp}.down_proj.weight\"]:\n            add_weight_and_scale_mapping(\n                linear_name,\n                [hf(linear_name)],\n                identity_transform,\n            )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [hf(mlc_name)],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return 
mapping\n"
  },
  {
    "path": "python/mlc_llm/model/ministral3/ministral3_model.py",
    "content": "\"\"\"\nImplementation for Ministral 3 architecture.\n\"\"\"\n\nimport dataclasses\nimport math\nfrom functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Ministral3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Ministral 3 model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    rms_norm_eps: float\n    vocab_size: int\n    attention_sink_size: int = 0\n    context_window_size: int = 0\n    dtype: str = \"float32\"\n    head_dim: int = 0\n    hidden_act: str = \"silu\"\n    max_batch_size: int = 1\n    num_key_value_heads: int = 0\n    position_embedding_base: int = 0\n    prefill_chunk_size: int = 0\n    rope_parameters: Optional[Dict[str, Any]] = None\n    sliding_window_size: int = 0\n    tensor_parallel_shards: int = 1\n    tie_word_embeddings: bool = False\n    weight_block_size: Optional[Tuple[int, int]] = None\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n    modules_to_not_convert: Tuple[str, ...] 
= dataclasses.field(default_factory=tuple)\n\n    @classmethod\n    def from_dict(  # type: ignore[override]\n        cls,\n        source: Dict[str, Any],\n    ) -> \"Ministral3Config\":\n        if \"text_config\" in source and isinstance(source[\"text_config\"], dict):\n            top_level = dict(source)\n            text_cfg = top_level.pop(\"text_config\")\n            merged: Dict[str, Any] = dict(top_level)\n            merged.update(text_cfg)\n            if \"tie_word_embeddings\" in source:\n                merged[\"tie_word_embeddings\"] = source[\"tie_word_embeddings\"]\n            if \"dtype\" in source:\n                merged[\"dtype\"] = source[\"dtype\"]\n            return super().from_dict(merged)\n        return super().from_dict(source)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches,too-many-statements\n        if \"quantization_config\" in self.kwargs:\n            quantization_config = self.kwargs.pop(\"quantization_config\")\n            if isinstance(quantization_config, dict):\n                activation_scheme = quantization_config.get(\"activation_scheme\", \"\")\n                quant_method = quantization_config.get(\"quant_method\", \"\")\n                weight_block_size = quantization_config.get(\"weight_block_size\")\n                modules_to_not_convert = quantization_config.get(\"modules_to_not_convert\", [])\n                if isinstance(modules_to_not_convert, list):\n                    self.modules_to_not_convert = tuple(modules_to_not_convert)\n                if quant_method == \"fp8\" and activation_scheme == \"static\":\n                    if weight_block_size is not None:\n                        self.weight_block_size = weight_block_size\n                        if (\n                            not isinstance(self.weight_block_size, (tuple, list))\n                            or len(self.weight_block_size) != 2\n                        ):\n                            raise ValueError(\n    
                            \"Invalid Ministral 3 model quantization config: \"\n                                \"weight_block_size must be a list or tuple of two integers, \"\n                                f\"got {self.weight_block_size} of type \"\n                                f\"{type(self.weight_block_size)}\"\n                            )\n                    else:\n                        # Set default block size if not provided.\n                        self.weight_block_size = (128, 128)\n                        logger.info(\n                            \"Setting default weight_block_size=%s, \"\n                            \"since quantization_config does not provide \"\n                            \"FP8 block-scale details required by \"\n                            \"MLC (activation_scheme=%s, quant_method=%s)\",\n                            self.weight_block_size,\n                            activation_scheme,\n                            quant_method,\n                        )\n                else:\n                    raise ValueError(\n                        \"Invalid Ministral 3 model quantization config: \"\n                        \"only FP8 static quantization is supported, \"\n                        f\"got activation_scheme={activation_scheme}, quant_method={quant_method}\"\n                    )\n            else:\n                raise ValueError(\n                    \"Invalid Ministral 3 model quantization config: \"\n                    \"unrecognized quantization config: \"\n                    f\"{quantization_config}\"\n                )\n\n        if self.position_embedding_base == 0:\n            if self.rope_parameters is not None and \"rope_theta\" in self.rope_parameters:\n                self.position_embedding_base = self.rope_parameters.pop(\"rope_theta\")\n            elif \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = 
self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.sliding_window_size == 0:\n            self.sliding_window_size = self.kwargs.pop(\"sliding_window\", -1)\n        if self.sliding_window_size is None:\n            # Sliding window is disabled.\n            self.sliding_window_size = -1\n        if self.context_window_size == 0:\n            if self.sliding_window_size == -1:\n                for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                    if name in self.kwargs:\n                        self.context_window_size = self.kwargs.pop(name)\n                        logger.info(\n                            \"%s not found in config.json. Falling back to %s (%d)\",\n                            bold(\"context_window_size\"),\n                            bold(name),\n                            self.context_window_size,\n                        )\n                        break\n                else:\n                    raise ValueError(\n                        \"Unable to determine the maximum sequence length, because none of \"\n                        \"`context_window_size`, `max_position_embeddings` or \"\n                        \"`max_sequence_length` is provided in `config.json`.\"\n                    )\n            else:\n                self.context_window_size = -1\n\n        if self.num_key_value_heads == 0:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        assert self.attention_sink_size >= 0\n        if self.prefill_chunk_size == 0:\n            prefill_chunk_size_candidates = []\n            if self.sliding_window_size != -1:\n                prefill_chunk_size_candidates.append(self.sliding_window_size)\n            if self.context_window_size != 
-1:\n                prefill_chunk_size_candidates.append(self.context_window_size)\n            # The 8192 cap also keeps min() valid when no candidate was collected.\n            self.prefill_chunk_size = min([*prefill_chunk_size_candidates, 8192])\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n            )\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass Ministral3Embedding(nn.Embedding):\n    \"\"\"The embedding module specialized for Ministral3 so that\n    it can be shared with the final lm_head.\n    \"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass Ministral3MLP(nn.Module):\n    \"\"\"Same as in Llama architecture (LlamaFFN).\"\"\"\n\n    def __init__(self, config: Ministral3Config):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = 
self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(self.act_fn(x1) * x2)\n\n\ndef yarn_get_sm_scale(scale=1, mscale=1):\n    if scale <= 1:\n        return 1.0\n    return 0.1 * mscale * math.log(scale) + 1.0\n\n\nclass Ministral3Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Same as LlamaAttention, but with sliding window attention using a rolling buffer cache.\"\"\"\n\n    def __init__(self, config: Ministral3Config):\n        self.head_dim = config.head_dim\n        if config.num_key_value_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_key_value_heads} key-value attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n        self.softmax_scale = self.head_dim ** (-0.5)\n        if config.rope_parameters is not None:\n            mscale_all_dim = config.rope_parameters.get(\"mscale_all_dim\", 0)\n            scaling_factor = config.rope_parameters[\"factor\"]\n            if mscale_all_dim:\n                sm_scale = yarn_get_sm_scale(scaling_factor, mscale_all_dim)\n                self.softmax_scale = self.softmax_scale * sm_scale * sm_scale\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        
qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.softmax_scale\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass Ministral3DecoderLayer(nn.Module):\n    \"\"\"Exact same as LlamaDecoderLayer.\"\"\"\n\n    def __init__(self, config: Ministral3Config):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = Ministral3Attention(config)\n        self.mlp = Ministral3MLP(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = 
self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Ministral3Model(nn.Module):\n    \"\"\"Exact same as LlamaModel.\"\"\"\n\n    def __init__(self, config: Ministral3Config):\n        assert config.hidden_size % config.num_attention_heads == 0\n        # self.embed_tokens = nn.Embedding(\"vocab_size\", config.hidden_size)\n        self.embed_tokens = Ministral3Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Ministral3DecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Mistral3ForConditionalGeneration(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Ministral3Config):\n        self.model = Ministral3Model(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(\n                config.hidden_size, config.vocab_size, bias=False\n            )  # \"vocab_size\"\n        self._mark_modules_no_quant(config.modules_to_not_convert)\n        self.num_hidden_layers = config.num_hidden_layers\n        
self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.rope_parameters = config.rope_parameters\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.sliding_window_size = config.sliding_window_size\n        self.dtype = config.dtype\n        self.weight_block_size = config.weight_block_size\n\n    def _mark_modules_no_quant(self, modules: Tuple[str, ...]):\n        for path in modules:\n            if not path:\n                continue\n            parts = path.split(\".\")\n            target = self\n            for part in parts:\n                if not hasattr(target, part):\n                    target = None\n                    break\n                target = getattr(target, part)\n            if target is not None:\n                setattr(target, \"no_quantization\", True)\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 
1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, 
paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_scaling=self.rope_parameters,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": 
nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/mistral/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/mistral/mistral_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Mistral parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization, make_awq_quant\n\nfrom .mistral_model import MistralConfig, MistralForCausalLM\n\nawq_quant = make_awq_quant(MistralForCausalLM)\n\n\nhuggingface = make_standard_hf_loader(\n    model_cls=MistralForCausalLM,\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n\n\ndef awq(model_config: MistralConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of AWQ parameters.\n\n    Parameters\n    ----------\n    model_config : MistralConfig\n        The configuration of the Mistral model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to AWQ.\n    \"\"\"\n    model, _ = awq_quant(model_config, quantization)\n    # With allow_extern=True, export_tvm returns (module, named_params, extern_modules).\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),  # type: ignore[attr-defined]\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{attn}.qkv_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{quantize_suffix}\",\n                    
f\"{attn}.k_proj.{quantize_suffix}\",\n                    f\"{attn}.v_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # Concat gate and up in MLP\n        mlp = f\"model.layers.{i}.mlp\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{mlp}.gate_up_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.{quantize_suffix}\",\n                    f\"{mlp}.up_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/mistral/mistral_model.py",
    "content": "\"\"\"\nImplementation for Mistral architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass MistralConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Mistral model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    rms_norm_eps: float\n    vocab_size: int\n    position_embedding_base: int = 0\n    num_key_value_heads: int = 0\n    head_dim: int = 0\n    context_window_size: int = 0\n    sliding_window_size: int = 0\n    prefill_chunk_size: int = 0\n    attention_sink_size: int = 4\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.sliding_window_size == 0:\n            self.sliding_window_size = self.kwargs.pop(\"sliding_window\", -1)\n        if self.sliding_window_size is None:\n            # Sliding window is disabled.\n            self.sliding_window_size = -1\n        if self.context_window_size == 0:\n            if self.sliding_window_size == -1:\n                for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                    if 
name in self.kwargs:\n                        self.context_window_size = self.kwargs.pop(name)\n                        logger.info(\n                            \"%s not found in config.json. Falling back to %s (%d)\",\n                            bold(\"context_window_size\"),\n                            bold(name),\n                            self.context_window_size,\n                        )\n                        break\n                else:\n                    raise ValueError(\n                        \"Unable to determine the maximum sequence length, because none of \"\n                        \"`context_window_size`, `max_position_embeddings` or \"\n                        \"`max_sequence_length` is provided in `config.json`.\"\n                    )\n            else:\n                self.context_window_size = -1\n\n        if self.num_key_value_heads == 0:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        assert self.attention_sink_size >= 0\n        if self.prefill_chunk_size == 0:\n            prefill_chunk_size_candidates = []\n            if self.sliding_window_size != -1:\n                prefill_chunk_size_candidates.append(self.sliding_window_size)\n            if self.context_window_size != -1:\n                prefill_chunk_size_candidates.append(self.context_window_size)\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(*prefill_chunk_size_candidates, 8192),\n            )\n            self.prefill_chunk_size = min(*prefill_chunk_size_candidates, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass MistralMLP(nn.Module):\n    \"\"\"Same as in Llama architecture (LlamaFFN).\"\"\"\n\n    def __init__(self, config: MistralConfig):\n        
super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass MistralAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Same as LlamaAttention, but with sliding window attention using a rolling buffer cache.\"\"\"\n\n    def __init__(self, config: MistralConfig):\n        self.head_dim = config.head_dim\n        if config.num_key_value_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_key_value_heads} key-value attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, 
h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass MistralDecoderLayer(nn.Module):\n    \"\"\"Exact same as LlamaDecoderLayer.\"\"\"\n\n    def __init__(self, config: MistralConfig):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = MistralAttention(config)\n        self.mlp = MistralMLP(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = 
self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass MistralModel(nn.Module):\n    \"\"\"Exact same as LlamaModel.\"\"\"\n\n    def __init__(self, config: MistralConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass MistralForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Same as LlamaForCausalLM, except for the use of sliding window attention.\"\"\"\n\n    def __init__(self, config: MistralConfig):\n        self.model = MistralModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, \"vocab_size\", bias=False)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        
self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.sliding_window_size = config.sliding_window_size\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits 
= self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n  
      )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], 
self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/mixtral/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/mixtral/mixtral_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Mixtral parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .mixtral_model import MixtralConfig, MixtralForCausalLM\n\n\ndef huggingface(model_config: MixtralConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : MixtralConfig\n        The configuration of the Mixtral model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = MixtralForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        mlc_name = f\"{attn}.qkv_proj.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"{attn}.q_proj.weight\",\n                f\"{attn}.k_proj.weight\",\n                f\"{attn}.v_proj.weight\",\n            ],\n            functools.partial(\n                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        # Add gates in MLP (when MoE is enabled)\n        mlp = 
f\"model.layers.{i}.block_sparse_moe\"\n        mlc_mlp = f\"model.layers.{i}.moe\"\n        mlc_name = f\"{mlc_mlp}.e1_e3.weight\"\n        mlc_param = named_parameters[mlc_name]\n\n        def combine_expert_gate_up(*hf_params, dtype):\n            stack = []\n            for i in range(0, len(hf_params), 2):\n                stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))\n            return np.stack(stack, axis=0).astype(dtype)\n\n        mapping.add_mapping(\n            mlc_name,\n            functools.reduce(\n                lambda a, b: a + b,\n                [\n                    [\n                        f\"{mlp}.experts.{expert_id}.w1.weight\",\n                        f\"{mlp}.experts.{expert_id}.w3.weight\",\n                    ]\n                    for expert_id in range(model_config.num_local_experts)\n                ],\n            ),\n            functools.partial(\n                combine_expert_gate_up,\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        mlc_name = f\"{mlc_mlp}.e2.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"{mlp}.experts.{expert_id}.w2.weight\"\n                for expert_id in range(model_config.num_local_experts)\n            ],\n            functools.partial(\n                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        mlc_name = f\"{mlc_mlp}.gate.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [f\"{mlp}.gate.weight\"],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, 
mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/mixtral/mixtral_model.py",
    "content": "\"\"\"Implementation for Mixtral architecture.\"\"\"\n\nimport dataclasses\n\nfrom tvm import tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.llama.llama_model import (\n    LlamaAttention,\n    LlamaConfig,\n    LlamaForCausalLM,\n    LlamaModel,\n)\nfrom mlc_llm.nn import PagedKVCache\nfrom mlc_llm.nn.expert import MixtralExperts\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass MixtralConfig(LlamaConfig):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Mixtral model.\"\"\"\n\n    num_local_experts: int = 0\n    num_experts_per_tok: int = 0\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals,fixme\n\n\nclass MixtralMoE(nn.Module):\n    \"\"\"Mixture of experts\"\"\"\n\n    def __init__(self, config: MixtralConfig):\n        super().__init__()\n        self.num_experts_per_tok = config.num_experts_per_tok\n        self.num_local_experts = config.num_local_experts\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MoE intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=config.num_local_experts,\n            bias=False,\n        )\n        self.e1_e3 = MixtralExperts(\n            self.num_local_experts,\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n        )\n        self.e2 = MixtralExperts(\n            
self.num_local_experts,\n            in_features=self.intermediate_size,\n            out_features=config.hidden_size,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n        )\n        self.dtype = \"float32\"\n\n    def forward(self, x: Tensor):\n        def _expert_forward(x: Tensor, indptr: Tensor):\n            x1_x3 = self.e1_e3(x, indptr)\n            x1, x3 = op.split(x1_x3, indices_or_sections=2, axis=-1)\n            x = self.e2(op.silu(x1) * x3, indptr)\n            return x\n\n        experts_per_tok = self.num_experts_per_tok  # activated experts per token\n        local_experts = self.num_local_experts  # total number of experts\n        batch_size, seq_len, hidden_size = x.shape\n        num_tokens = batch_size * seq_len\n        x = x.reshape(num_tokens, hidden_size)\n        # gate: [num_tokens, local_experts]\n        gate: Tensor = self.gate(x)\n        # expert_weights: [num_tokens, experts_per_tok]\n        # expert_indices: [num_tokens, experts_per_tok]\n        expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(gate, experts_per_tok)\n        use_ft = (\n            op_ext.get_store().cutlass_group_gemm or op_ext.get_store().faster_transformer\n        ) and self.dtype == \"float16\"\n        if num_tokens == 1:\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            x = _expert_forward(x, expert_indices)\n        else:\n            # cumsum: [num_tokens * local_experts]\n            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, local_experts)\n            # indices: [num_tokens * experts_per_tok]\n            reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)\n            if use_ft:\n                # indptr: [num_local_experts]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum, local_experts, num_tokens, inclusive=True, out_dtype=\"int64\"\n                )\n            else:\n                # indptr: 
[num_local_experts + 1]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum,\n                    local_experts,\n                    num_tokens,\n                    inclusive=False,\n                    out_dtype=\"int32\",\n                )\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            x = op.take(x, token_indices, axis=0)\n            x = _expert_forward(x, indptr)\n            x = op_ext.moe_misc.scatter_output(x, reverse_indices)\n        # x: [num_tokens, experts_per_tok, hidden_size]\n        x = x.reshape(  # pylint: disable=too-many-function-args\n            num_tokens, experts_per_tok, hidden_size\n        ) * expert_weights.reshape(  # pylint: disable=too-many-function-args\n            num_tokens, experts_per_tok, 1\n        )\n        # x: [num_tokens, hidden_size]\n        x = op_ext.moe_misc.moe_sum(x, dim=1)\n        x = x.reshape(batch_size, seq_len, hidden_size)  # pylint: disable=too-many-function-args\n        return x\n\n\nclass MixtralDecoderLayer(nn.Module):\n    \"\"\"Mixtral decoder layer\"\"\"\n\n    def __init__(self, config: MixtralConfig):\n        eps = config.rms_norm_eps\n        self.self_attn = LlamaAttention(config)\n        self.moe = MixtralMoE(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, -1, eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.moe.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, 
tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(self.moe.e1_e3, tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=1))\n            _set(self.moe.e2, tp.ShardSingleDim(\"_shard_mlp_down\", dim=2))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, attention_mask: Tensor, total_seq_len: tir.Var):\n        \"\"\"Decoder layer forward pass: attention and MoE, each with a residual connection.\"\"\"\n        out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, total_seq_len)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.moe(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def batch_forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.moe(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass MixtralModel(LlamaModel):\n    \"\"\"Same as LlamaModel, but built from Mixtral decoder layers.\"\"\"\n\n    def __init__(self, config: MixtralConfig):\n        super().__init__(config)\n        self.layers = nn.ModuleList(\n            [MixtralDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n\n\nclass MixtralForCausalLM(LlamaForCausalLM):\n    \"\"\"Same as LlamaForCausalLM.\"\"\"\n\n    def __init__(self, config: MixtralConfig):\n        super().__init__(config)\n        self.model = MixtralModel(config)\n"
  },
  {
    "path": "python/mlc_llm/model/model.py",
    "content": "\"\"\"A centralized registry of all existing model architectures and their configurations.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Callable, Dict, Literal, Optional, Tuple\n\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.loader import ExternMapping, QuantizeMapping\nfrom mlc_llm.quantization import make_quantization_functions\nfrom mlc_llm.quantization.quantization import Quantization\n\nfrom .baichuan import baichuan_loader, baichuan_model\nfrom .bert import bert_loader, bert_model\nfrom .chatglm3 import chatglm3_loader, chatglm3_model\nfrom .cohere import cohere_loader, cohere_model\nfrom .deepseek import deepseek_loader, deepseek_model\nfrom .deepseek_v2 import deepseek_v2_loader, deepseek_v2_model\nfrom .eagle import eagle_loader, eagle_model\nfrom .gemma import gemma_loader, gemma_model\nfrom .gemma2 import gemma2_loader, gemma2_model\nfrom .gemma3 import gemma3_loader, gemma3_model\nfrom .gpt2 import gpt2_loader, gpt2_model\nfrom .gpt_bigcode import gpt_bigcode_loader, gpt_bigcode_model\nfrom .gpt_j import gpt_j_loader, gpt_j_model\nfrom .gpt_neox import gpt_neox_loader, gpt_neox_model\nfrom .internlm import internlm_loader, internlm_model\nfrom .internlm2 import internlm2_loader, internlm2_model\nfrom .llama import llama_loader, llama_model\nfrom .llama4 import llama4_loader, llama4_model\nfrom .llava import llava_loader, llava_model\nfrom .medusa import medusa_loader, medusa_model\nfrom .minicpm import minicpm_loader, minicpm_model\nfrom .ministral3 import ministral3_loader, ministral3_model\nfrom .mistral import mistral_loader, mistral_model\nfrom .mixtral import mixtral_loader, mixtral_model\nfrom .nemotron import nemotron_loader, nemotron_model\nfrom .olmo import olmo_loader, olmo_model\nfrom .orion import orion_loader, orion_model\nfrom .phi import phi_loader, phi_model\nfrom .phi3 import phi3_loader, phi3_model\nfrom .phi3v import phi3v_loader, phi3v_model\nfrom .qwen import qwen_loader, qwen_model\nfrom .qwen2 import qwen2_loader, qwen2_model\n
from .qwen2_moe import qwen2_moe_loader, qwen2_moe_model\nfrom .qwen3 import qwen3_loader, qwen3_model\nfrom .qwen3_moe import qwen3_moe_loader, qwen3_moe_model\nfrom .rwkv5 import rwkv5_loader, rwkv5_model\nfrom .rwkv6 import rwkv6_loader, rwkv6_model\nfrom .stable_lm import stablelm_loader, stablelm_model\nfrom .starcoder2 import starcoder2_loader, starcoder2_model\n\nModelConfig = Any\n\"\"\"A ModelConfig is an object that represents a model architecture. It is required to have\na class method `from_file` with the following signature:\n\n    def from_file(cls, path: Path) -> ModelConfig:\n        ...\n\"\"\"\n\nFuncGetExternMap = Callable[[ModelConfig, Quantization], ExternMapping]\nFuncQuantization = Callable[[ModelConfig, Quantization], Tuple[nn.Module, QuantizeMapping]]\n\n\n@dataclasses.dataclass\nclass EmbeddingMetadata:\n    \"\"\"Embedding model metadata.\n\n    Parameters\n    ----------\n    model_type: Literal[\"encoder\", \"decoder\"]\n        The type of the embedding model.\n\n    pooling_strategy: Literal[\"cls\", \"mean\", \"last\"]\n        The pooling strategy to use for the embedding model.\n\n    normalize: bool = True\n        Whether to normalize the embedding. Defaults to True.\n    \"\"\"\n\n    model_type: Literal[\"encoder\", \"decoder\"]\n    pooling_strategy: Literal[\"cls\", \"mean\", \"last\"]\n    normalize: bool = True\n\n\n@dataclasses.dataclass\nclass Model:\n    \"\"\"All about a model architecture: its configuration, its parameter loader and quantization.\n\n    Parameters\n    ----------\n    name : str\n        The name of the model.\n\n    model : Callable[[ModelConfig], nn.Module]\n        A method that creates the `nn.Module` that represents the model from `ModelConfig`.\n\n    config : ModelConfig\n        A class that has a `from_file` class method, whose signature is \"Path -> ModelConfig\".\n\n    source : Dict[str, FuncGetExternMap]\n        A dictionary that maps the name of a source format to parameter mapping.\n
\n    quantize: Dict[str, FuncQuantization]\n        A dictionary that maps the name of a quantization method to a function that returns the\n        quantized model and the quantization parameter mapping.\n\n    model_task: Literal[\"chat\", \"embedding\"] = \"chat\"\n        The task of the model, either a chat or an embedding model. Defaults to \"chat\".\n\n    embedding_metadata: Optional[EmbeddingMetadata] = None\n        Metadata for the embedding model. Defaults to None.\n    \"\"\"\n\n    name: str\n    config: ModelConfig\n    model: Callable[[ModelConfig], nn.Module]\n    source: Dict[str, FuncGetExternMap]\n    quantize: Dict[str, FuncQuantization]\n\n    model_task: Literal[\"chat\", \"embedding\"] = \"chat\"\n    embedding_metadata: Optional[EmbeddingMetadata] = None\n\n    def __post_init__(self):\n        if self.model_task == \"embedding\" and self.embedding_metadata is None:\n            raise ValueError(f\"[Model] {self.name}: Embedding model must have embedding metadata.\")\n        if self.model_task == \"chat\" and self.embedding_metadata is not None:\n            raise ValueError(\n                f\"[Model] {self.name}: Chat model not expected to have embedding metadata.\"\n            )\n\n\nMODELS: Dict[str, Model] = {\n    \"llama\": Model(\n        name=\"llama\",\n        model=llama_model.LlamaForCausalLM,\n        config=llama_model.LlamaConfig,\n        source={\n            \"huggingface-torch\": llama_loader.huggingface,\n            \"huggingface-safetensor\": llama_loader.huggingface,\n            \"awq\": llama_loader.awq,\n        },\n        quantize=make_quantization_functions(\n            llama_model.LlamaForCausalLM,\n            supports_awq=True,\n            supports_per_tensor=True,\n        ),\n    ),\n    \"llama4\": Model(\n        name=\"llama4\",\n        model=llama4_model.Llama4ForCausalLM,\n        config=llama4_model.Llama4Config,\n        source={\n            \"huggingface-torch\": llama4_loader.huggingface,\n          
  \"huggingface-safetensor\": llama4_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            llama4_model.Llama4ForCausalLM,\n            supports_per_tensor=True,\n        ),\n    ),\n    \"mistral\": Model(\n        name=\"mistral\",\n        model=mistral_model.MistralForCausalLM,\n        config=mistral_model.MistralConfig,\n        source={\n            \"huggingface-torch\": mistral_loader.huggingface,\n            \"huggingface-safetensor\": mistral_loader.huggingface,\n            \"awq\": mistral_loader.awq,\n        },\n        quantize=make_quantization_functions(\n            mistral_model.MistralForCausalLM,\n        ),\n    ),\n    \"ministral3\": Model(\n        name=\"ministral3\",\n        model=ministral3_model.Mistral3ForConditionalGeneration,\n        config=ministral3_model.Ministral3Config,\n        source={\n            \"huggingface-torch\": ministral3_loader.huggingface,\n            \"huggingface-safetensor\": ministral3_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            ministral3_model.Mistral3ForConditionalGeneration,\n            supports_block_scale=True,\n        ),\n    ),\n    \"gemma\": Model(\n        name=\"gemma\",\n        model=gemma_model.GemmaForCausalLM,\n        config=gemma_model.GemmaConfig,\n        source={\n            \"huggingface-torch\": gemma_loader.huggingface,\n            \"huggingface-safetensor\": gemma_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            gemma_model.GemmaForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"gemma2\": Model(\n        name=\"gemma2\",\n        model=gemma2_model.Gemma2ForCausalLM,\n        config=gemma2_model.Gemma2Config,\n        source={\n            \"huggingface-torch\": gemma2_loader.huggingface,\n            \"huggingface-safetensor\": gemma2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n          
  gemma2_model.Gemma2ForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"gemma3\": Model(\n        name=\"gemma3\",\n        model=gemma3_model.Gemma3ForCausalLM,\n        config=gemma3_model.Gemma3Config,\n        source={\n            \"huggingface-torch\": gemma3_loader.huggingface,\n            \"huggingface-safetensor\": gemma3_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            gemma3_model.Gemma3ForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"gemma3_text\": Model(\n        name=\"gemma3_text\",\n        model=gemma3_model.Gemma3ForCausalLM,\n        config=gemma3_model.Gemma3Config,\n        source={\n            \"huggingface-torch\": gemma3_loader.huggingface,\n            \"huggingface-safetensor\": gemma3_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            gemma3_model.Gemma3ForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"gpt2\": Model(\n        name=\"gpt2\",\n        model=gpt2_model.GPT2LMHeadModel,\n        config=gpt2_model.GPT2Config,\n        source={\n            \"huggingface-torch\": gpt2_loader.huggingface,\n            \"huggingface-safetensor\": gpt2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            gpt2_model.GPT2LMHeadModel,\n        ),\n    ),\n    \"mixtral\": Model(\n        name=\"mixtral\",\n        model=mixtral_model.MixtralForCausalLM,\n        config=mixtral_model.MixtralConfig,\n        source={\n            \"huggingface-torch\": mixtral_loader.huggingface,\n            \"huggingface-safetensor\": mixtral_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            mixtral_model.MixtralForCausalLM,\n            supports_awq=True,\n            awq_unsupported_message=\"AWQ is not implemented for Mixtral models.\",\n            supports_per_tensor=True,\n        ),\n    ),\n    \"gpt_neox\": 
Model(\n        name=\"gpt_neox\",\n        model=gpt_neox_model.GPTNeoXForCausalLM,\n        config=gpt_neox_model.GPTNeoXConfig,\n        source={\n            \"huggingface-torch\": gpt_neox_loader.huggingface,\n            \"huggingface-safetensor\": gpt_neox_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            gpt_neox_model.GPTNeoXForCausalLM,\n        ),\n    ),\n    \"gpt_bigcode\": Model(\n        name=\"gpt_bigcode\",\n        model=gpt_bigcode_model.GPTBigCodeForCausalLM,\n        config=gpt_bigcode_model.GPTBigCodeConfig,\n        source={\n            \"huggingface-torch\": gpt_bigcode_loader.huggingface,\n            \"huggingface-safetensor\": gpt_bigcode_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            gpt_bigcode_model.GPTBigCodeForCausalLM,\n        ),\n    ),\n    \"phi-msft\": Model(\n        name=\"phi-msft\",\n        model=phi_model.PhiForCausalLM,\n        config=phi_model.PhiConfig,\n        source={\n            \"huggingface-torch\": phi_loader.huggingface,\n            \"huggingface-safetensor\": phi_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            phi_model.PhiForCausalLM,\n        ),\n    ),\n    \"phi\": Model(\n        name=\"phi\",\n        model=phi_model.PhiForCausalLM,\n        config=phi_model.Phi1Config,\n        source={\n            \"huggingface-torch\": phi_loader.phi1_huggingface,\n            \"huggingface-safetensor\": phi_loader.phi1_huggingface,\n        },\n        quantize=make_quantization_functions(\n            phi_model.PhiForCausalLM,\n        ),\n    ),\n    \"phi3\": Model(\n        name=\"phi3\",\n        model=phi3_model.Phi3ForCausalLM,\n        config=phi3_model.Phi3Config,\n        source={\n            \"huggingface-torch\": phi3_loader.phi3_huggingface,\n            \"huggingface-safetensor\": phi3_loader.phi3_huggingface,\n        },\n        
quantize=make_quantization_functions(\n            phi3_model.Phi3ForCausalLM,\n        ),\n    ),\n    \"phi3_v\": Model(\n        name=\"phi3_v\",\n        model=phi3v_model.Phi3VForCausalLM,\n        config=phi3v_model.Phi3VConfig,\n        source={\n            \"huggingface-torch\": phi3v_loader.huggingface,\n            \"huggingface-safetensor\": phi3v_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            phi3v_model.Phi3VForCausalLM,\n        ),\n    ),\n    \"qwen\": Model(\n        name=\"qwen\",\n        model=qwen_model.QWenLMHeadModel,\n        config=qwen_model.QWenConfig,\n        source={\n            \"huggingface-torch\": qwen_loader.huggingface,\n            \"huggingface-safetensor\": qwen_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            qwen_model.QWenLMHeadModel,\n        ),\n    ),\n    \"qwen2\": Model(\n        name=\"qwen2\",\n        model=qwen2_model.QWen2LMHeadModel,\n        config=qwen2_model.QWen2Config,\n        source={\n            \"huggingface-torch\": qwen2_loader.huggingface,\n            \"huggingface-safetensor\": qwen2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            qwen2_model.QWen2LMHeadModel,\n        ),\n    ),\n    \"qwen2_moe\": Model(\n        name=\"qwen2_moe\",\n        model=qwen2_moe_model.Qwen2MoeForCausalLM,\n        config=qwen2_moe_model.Qwen2MoeConfig,\n        source={\n            \"huggingface-torch\": qwen2_moe_loader.huggingface,\n            \"huggingface-safetensor\": qwen2_moe_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            qwen2_moe_model.Qwen2MoeForCausalLM,\n        ),\n    ),\n    \"qwen3\": Model(\n        name=\"qwen3\",\n        model=qwen3_model.Qwen3LMHeadModel,\n        config=qwen3_model.Qwen3Config,\n        source={\n            \"huggingface-torch\": qwen3_loader.huggingface,\n            \"huggingface-safetensor\": 
qwen3_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            qwen3_model.Qwen3LMHeadModel,\n            supports_block_scale=True,\n        ),\n    ),\n    \"qwen3-embedding\": Model(\n        name=\"qwen3-embedding\",\n        model=qwen3_model.Qwen3EmbeddingModel,\n        config=qwen3_model.Qwen3Config,\n        source={\n            \"huggingface-torch\": qwen3_loader.huggingface_embedding,\n            \"huggingface-safetensor\": qwen3_loader.huggingface_embedding,\n        },\n        quantize=make_quantization_functions(\n            qwen3_model.Qwen3EmbeddingModel,\n            supports_block_scale=True,\n        ),\n        model_task=\"embedding\",\n        embedding_metadata=EmbeddingMetadata(\n            model_type=\"decoder\",\n            pooling_strategy=\"last\",\n            normalize=True,\n        ),\n    ),\n    \"qwen3_moe\": Model(\n        name=\"qwen3_moe\",\n        model=qwen3_moe_model.Qwen3MoeForCausalLM,\n        config=qwen3_moe_model.Qwen3MoeConfig,\n        source={\n            \"huggingface-torch\": qwen3_moe_loader.huggingface,\n            \"huggingface-safetensor\": qwen3_moe_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            qwen3_moe_model.Qwen3MoeForCausalLM,\n            supports_block_scale=True,\n        ),\n    ),\n    \"deepseek_v2\": Model(\n        name=\"deepseek_v2\",\n        model=deepseek_v2_model.DeepseekV2ForCausalLM,\n        config=deepseek_v2_model.DeepseekV2Config,\n        source={\n            \"huggingface-torch\": deepseek_v2_loader.huggingface,\n            \"huggingface-safetensor\": deepseek_v2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            deepseek_v2_model.DeepseekV2ForCausalLM,\n        ),\n    ),\n    \"deepseek_v3\": Model(\n        name=\"deepseek_v3\",\n        model=deepseek_v2_model.DeepseekV2ForCausalLM,\n        config=deepseek_v2_model.DeepseekV2Config,\n        
source={\n            \"huggingface-torch\": deepseek_v2_loader.huggingface,\n            \"huggingface-safetensor\": deepseek_v2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            deepseek_v2_model.DeepseekV2ForCausalLM,\n            supports_block_scale=True,\n        ),\n    ),\n    \"stablelm\": Model(\n        name=\"stablelm\",\n        model=stablelm_model.StableLmForCausalLM,\n        config=stablelm_model.StableLmConfig,\n        source={\n            \"huggingface-torch\": stablelm_loader.huggingface,\n            \"huggingface-safetensor\": stablelm_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            stablelm_model.StableLmForCausalLM,\n        ),\n    ),\n    \"baichuan\": Model(\n        name=\"baichuan\",\n        model=baichuan_model.BaichuanForCausalLM,\n        config=baichuan_model.BaichuanConfig,\n        source={\n            \"huggingface-torch\": baichuan_loader.huggingface,\n            \"huggingface-safetensor\": baichuan_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            baichuan_model.BaichuanForCausalLM,\n        ),\n    ),\n    \"internlm\": Model(\n        name=\"internlm\",\n        model=internlm_model.InternLMForCausalLM,\n        config=internlm_model.InternLMConfig,\n        source={\n            \"huggingface-torch\": internlm_loader.huggingface,\n            \"huggingface-safetensor\": internlm_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            internlm_model.InternLMForCausalLM,\n        ),\n    ),\n    \"internlm2\": Model(\n        name=\"internlm2\",\n        model=internlm2_model.InternLM2ForCausalLM,\n        config=internlm2_model.InternLM2Config,\n        source={\n            \"huggingface-torch\": internlm2_loader.huggingface,\n            \"huggingface-safetensor\": internlm2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n     
       internlm2_model.InternLM2ForCausalLM,\n        ),\n    ),\n    \"rwkv5\": Model(\n        name=\"rwkv5\",\n        model=rwkv5_model.RWKV5_ForCausalLM,\n        config=rwkv5_model.RWKV5Config,\n        source={\n            \"huggingface-torch\": rwkv5_loader.huggingface,\n            \"huggingface-safetensor\": rwkv5_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            rwkv5_model.RWKV5_ForCausalLM,\n        ),\n    ),\n    \"orion\": Model(\n        name=\"orion\",\n        model=orion_model.OrionForCausalLM,\n        config=orion_model.OrionConfig,\n        source={\n            \"huggingface-torch\": orion_loader.huggingface,\n            \"huggingface-safetensor\": orion_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            orion_model.OrionForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"llava\": Model(\n        name=\"llava\",\n        model=llava_model.LlavaForCausalLM,\n        config=llava_model.LlavaConfig,\n        source={\n            \"huggingface-torch\": llava_loader.huggingface,\n            \"huggingface-safetensor\": llava_loader.huggingface,\n            \"awq\": llava_loader.awq,\n        },\n        quantize=make_quantization_functions(\n            llava_model.LlavaForCausalLM,\n            supports_awq=True,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"rwkv6\": Model(\n        name=\"rwkv6\",\n        model=rwkv6_model.RWKV6_ForCausalLM,\n        config=rwkv6_model.RWKV6Config,\n        source={\n            \"huggingface-torch\": rwkv6_loader.huggingface,\n            \"huggingface-safetensor\": rwkv6_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            rwkv6_model.RWKV6_ForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"chatglm\": Model(\n        name=\"chatglm\",\n        model=chatglm3_model.ChatGLMForCausalLM,\n        
config=chatglm3_model.GLMConfig,\n        source={\n            \"huggingface-torch\": chatglm3_loader.huggingface,\n            \"huggingface-safetensor\": chatglm3_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            chatglm3_model.ChatGLMForCausalLM,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"eagle\": Model(\n        name=\"eagle\",\n        model=eagle_model.EagleForCausalLM,\n        config=eagle_model.EagleConfig,\n        source={\n            \"huggingface-torch\": eagle_loader.huggingface,\n            \"huggingface-safetensor\": eagle_loader.huggingface,\n            \"awq\": eagle_loader.awq,\n        },\n        quantize=make_quantization_functions(\n            eagle_model.EagleForCausalLM,\n            supports_awq=True,\n        ),\n    ),\n    \"bert\": Model(\n        name=\"bert\",\n        model=bert_model.BertModel,\n        config=bert_model.BertConfig,\n        source={\n            \"huggingface-torch\": bert_loader.huggingface,\n            \"huggingface-safetensor\": bert_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            bert_model.BertModel,\n        ),\n        model_task=\"embedding\",\n        embedding_metadata=EmbeddingMetadata(\n            model_type=\"encoder\",\n            pooling_strategy=\"cls\",\n            normalize=True,\n        ),\n    ),\n    \"medusa\": Model(\n        name=\"medusa\",\n        model=medusa_model.MedusaModel,\n        config=medusa_model.MedusaConfig,\n        source={\n            \"huggingface-torch\": medusa_loader.huggingface,\n            \"huggingface-safetensor\": medusa_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            medusa_model.MedusaModel,\n            supports_group_quant=False,\n            supports_ft_quant=False,\n        ),\n    ),\n    \"starcoder2\": Model(\n        name=\"starcoder2\",\n        
model=starcoder2_model.Starcoder2ForCausalLM,\n        config=starcoder2_model.Starcoder2Config,\n        source={\n            \"huggingface-torch\": starcoder2_loader.huggingface,\n            \"huggingface-safetensor\": starcoder2_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            starcoder2_model.Starcoder2ForCausalLM,\n        ),\n    ),\n    \"cohere\": Model(\n        name=\"cohere\",\n        model=cohere_model.CohereForCausalLM,\n        config=cohere_model.CohereConfig,\n        source={\n            \"huggingface-torch\": cohere_loader.huggingface,\n            \"huggingface-safetensor\": cohere_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            cohere_model.CohereForCausalLM,\n        ),\n    ),\n    \"minicpm\": Model(\n        name=\"minicpm\",\n        model=minicpm_model.MiniCPMForCausalLM,\n        config=minicpm_model.MiniCPMConfig,\n        source={\n            \"huggingface-torch\": minicpm_loader.huggingface,\n            \"huggingface-safetensor\": minicpm_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            minicpm_model.MiniCPMForCausalLM,\n        ),\n    ),\n    \"deepseek\": Model(\n        name=\"deepseek\",\n        model=deepseek_model.DeepseekForCausalLM,\n        config=deepseek_model.DeepseekConfig,\n        source={\n            \"huggingface-torch\": deepseek_loader.huggingface,\n            \"huggingface-safetensor\": deepseek_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            deepseek_model.DeepseekForCausalLM,\n        ),\n    ),\n    \"gptj\": Model(\n        name=\"gptj\",\n        model=gpt_j_model.GPTJForCausalLM,\n        config=gpt_j_model.GPTJConfig,\n        source={\n            \"huggingface-torch\": gpt_j_loader.huggingface,\n            \"huggingface-safetensor\": gpt_j_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n        
    gpt_j_model.GPTJForCausalLM,\n        ),\n    ),\n    \"olmo\": Model(\n        name=\"olmo\",\n        model=olmo_model.OLMoForCausalLM,\n        config=olmo_model.OLMoConfig,\n        source={\n            \"huggingface-torch\": olmo_loader.huggingface,\n            \"huggingface-safetensor\": olmo_loader.huggingface,\n            \"awq\": olmo_loader.awq,\n        },\n        quantize=make_quantization_functions(\n            olmo_model.OLMoForCausalLM,\n            supports_awq=True,\n            supports_per_tensor=True,\n        ),\n    ),\n    \"nemotron\": Model(\n        name=\"nemotron\",\n        model=nemotron_model.NemotronForCausalLM,\n        config=nemotron_model.NemotronConfig,\n        source={\n            \"huggingface-torch\": nemotron_loader.huggingface,\n            \"huggingface-safetensor\": nemotron_loader.huggingface,\n        },\n        quantize=make_quantization_functions(\n            nemotron_model.NemotronForCausalLM,\n            supports_awq=True,\n            supports_per_tensor=True,\n        ),\n    ),\n    \"bert-bge\": Model(\n        name=\"bert-bge\",\n        model=bert_model.BertModel,\n        config=bert_model.BertConfig,\n        source={\n            \"huggingface-torch\": bert_loader.huggingface_bge,\n            \"huggingface-safetensor\": bert_loader.huggingface_bge,\n        },\n        quantize=make_quantization_functions(\n            bert_model.BertModel,\n        ),\n        model_task=\"embedding\",\n        embedding_metadata=EmbeddingMetadata(\n            model_type=\"encoder\",\n            pooling_strategy=\"cls\",\n            normalize=True,\n        ),\n    ),\n}\n"
  },
  {
    "path": "python/mlc_llm/model/model_preset.py",
    "content": "\"\"\"A builtin set of models available in MLC LLM.\"\"\"\n\nfrom typing import Any, Dict  # pylint: disable=too-many-lines\n\n# pylint: disable=too-many-lines\n\nMODEL_PRESETS: Dict[str, Any] = {\n    \"llama2_7b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 11008,\n        \"max_position_embeddings\": 2048,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 32,\n        \"pad_token_id\": 0,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.31.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32000,\n        \"context_window_size\": 2048,\n        \"prefill_chunk_size\": 2048,\n    },\n    \"llama2_13b\": {\n        \"_name_or_path\": \"meta-llama/Llama-2-13b-hf\",\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 5120,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 13824,\n        \"max_position_embeddings\": 2048,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 40,\n        \"num_hidden_layers\": 40,\n        \"num_key_value_heads\": 40,\n        \"pad_token_id\": 0,\n        \"pretraining_tp\": 2,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.31.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32000,\n        \"context_window_size\": 2048,\n        
\"prefill_chunk_size\": 2048,\n    },\n    \"llama2_70b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 8192,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 28672,\n        \"max_position_embeddings\": 2048,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 64,\n        \"num_hidden_layers\": 80,\n        \"num_key_value_heads\": 8,\n        \"pad_token_id\": 0,\n        \"rms_norm_eps\": 1e-05,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.31.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32000,\n        \"context_window_size\": 2048,\n        \"prefill_chunk_size\": 2048,\n    },\n    \"codellama_7b\": {\n        \"_name_or_path\": \"codellama/CodeLlama-7b-hf\",\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 11008,\n        \"max_position_embeddings\": 16384,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 32,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.33.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32016,\n        \"context_window_size\": 2048,\n        \"prefill_chunk_size\": 2048,\n    },\n    \"codellama_13b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 
5120,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 13824,\n        \"max_position_embeddings\": 16384,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 40,\n        \"num_hidden_layers\": 40,\n        \"num_key_value_heads\": 40,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.32.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32016,\n        \"context_window_size\": 2048,\n        \"prefill_chunk_size\": 2048,\n    },\n    \"codellama_34b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 8192,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 22016,\n        \"max_position_embeddings\": 16384,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 64,\n        \"num_hidden_layers\": 48,\n        \"num_key_value_heads\": 8,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.32.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32016,\n        \"context_window_size\": 2048,\n        \"prefill_chunk_size\": 2048,\n    },\n    \"tinyllama_1b_chat_v0.4\": {\n        \"_name_or_path\": \"/data/tianduo/tinyllama-ft/checkpoint-3890\",\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 5632,\n        \"max_position_embeddings\": 2048,\n        
\"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 22,\n        \"num_key_value_heads\": 4,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float32\",\n        \"transformers_version\": \"4.33.1\",\n        \"use_cache\": False,\n        \"vocab_size\": 32003,\n    },\n    \"tinyllama_1b_chat_v1.0\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 5632,\n        \"max_position_embeddings\": 2048,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 22,\n        \"num_key_value_heads\": 4,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.35.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32000,\n    },\n    \"mistral_7b\": {\n        \"architectures\": [\"MistralForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 32768,\n        \"model_type\": \"mistral\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.34.0.dev0\",\n      
  \"use_cache\": True,\n        \"vocab_size\": 32000,\n        \"sliding_window_size\": 4096,\n        \"prefill_chunk_size\": 128,\n        \"attention_sink_size\": 4,\n    },\n    \"mistral_7b_v03\": {\n        \"architectures\": [\"MistralForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 32768,\n        \"model_type\": \"mistral\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": None,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32768,\n    },\n    \"ministral3_3b_2512\": {\n        \"architectures\": [\"Mistral3ForConditionalGeneration\"],\n        \"dtype\": \"bfloat16\",\n        \"image_token_index\": 10,\n        \"model_type\": \"ministral3\",\n        \"multimodal_projector_bias\": False,\n        \"projector_hidden_act\": \"gelu\",\n        \"spatial_merge_size\": 2,\n        \"text_config\": {\n            \"attention_dropout\": 0.0,\n            \"head_dim\": 128,\n            \"hidden_act\": \"silu\",\n            \"hidden_size\": 3072,\n            \"initializer_range\": 0.02,\n            \"intermediate_size\": 9216,\n            \"max_position_embeddings\": 262144,\n            \"model_type\": \"ministral3\",\n            \"num_attention_heads\": 32,\n            \"num_hidden_layers\": 26,\n            \"num_key_value_heads\": 8,\n            \"rms_norm_eps\": 1e-05,\n            \"rope_parameters\": {\n                \"beta_fast\": 32.0,\n                \"beta_slow\": 1.0,\n                \"factor\": 
16.0,\n                \"llama_4_scaling_beta\": 0.1,\n                \"mscale\": 1.0,\n                \"mscale_all_dim\": 1.0,\n                \"original_max_position_embeddings\": 16384,\n                \"rope_theta\": 1000000.0,\n                \"rope_type\": \"yarn\",\n                \"type\": \"yarn\",\n            },\n            \"sliding_window\": None,\n            \"tie_word_embeddings\": True,\n            \"use_cache\": True,\n            \"vocab_size\": 131072,\n        },\n        \"transformers_version\": \"5.0.0.dev0\",\n        \"vision_config\": {\n            \"attention_dropout\": 0.0,\n            \"head_dim\": 64,\n            \"hidden_act\": \"silu\",\n            \"hidden_size\": 1024,\n            \"image_size\": 1540,\n            \"initializer_range\": 0.02,\n            \"intermediate_size\": 4096,\n            \"model_type\": \"pixtral\",\n            \"num_attention_heads\": 16,\n            \"num_channels\": 3,\n            \"num_hidden_layers\": 24,\n            \"patch_size\": 14,\n            \"rope_parameters\": {\"rope_theta\": 10000.0, \"rope_type\": \"default\"},\n            \"rope_theta\": 10000.0,\n        },\n        \"vision_feature_layer\": -1,\n    },\n    \"gpt2\": {\n        \"activation_function\": \"gelu_new\",\n        \"architectures\": [\"GPT2LMHeadModel\"],\n        \"attn_pdrop\": 0.1,\n        \"bos_token_id\": 50256,\n        \"embd_pdrop\": 0.1,\n        \"eos_token_id\": 50256,\n        \"initializer_range\": 0.02,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"gpt2\",\n        \"n_ctx\": 1024,\n        \"n_embd\": 768,\n        \"n_head\": 12,\n        \"n_layer\": 12,\n        \"n_positions\": 1024,\n        \"resid_pdrop\": 0.1,\n        \"summary_activation\": None,\n        \"summary_first_dropout\": 0.1,\n        \"summary_proj_to_labels\": True,\n        \"summary_type\": \"cls_index\",\n        \"summary_use_proj\": True,\n        \"task_specific_params\": 
{\"text-generation\": {\"do_sample\": True, \"max_length\": 50}},\n        \"vocab_size\": 50257,\n    },\n    \"gpt2_medium\": {\n        \"activation_function\": \"gelu_new\",\n        \"architectures\": [\"GPT2LMHeadModel\"],\n        \"attn_pdrop\": 0.1,\n        \"bos_token_id\": 50256,\n        \"embd_pdrop\": 0.1,\n        \"eos_token_id\": 50256,\n        \"initializer_range\": 0.02,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"gpt2\",\n        \"n_ctx\": 1024,\n        \"n_embd\": 1024,\n        \"n_head\": 16,\n        \"n_layer\": 24,\n        \"n_positions\": 1024,\n        \"n_special\": 0,\n        \"predict_special_tokens\": True,\n        \"resid_pdrop\": 0.1,\n        \"summary_activation\": None,\n        \"summary_first_dropout\": 0.1,\n        \"summary_proj_to_labels\": True,\n        \"summary_type\": \"cls_index\",\n        \"summary_use_proj\": True,\n        \"task_specific_params\": {\"text-generation\": {\"do_sample\": True, \"max_length\": 50}},\n        \"vocab_size\": 50257,\n    },\n    \"gpt_bigcode\": {\n        \"activation_function\": \"gelu_pytorch_tanh\",\n        \"architectures\": [\"GPTBigCodeForCausalLM\"],\n        \"attention_softmax_in_fp32\": True,\n        \"multi_query\": True,\n        \"attn_pdrop\": 0.1,\n        \"bos_token_id\": 49152,\n        \"embd_pdrop\": 0.1,\n        \"eos_token_id\": 49152,\n        \"initializer_range\": 0.02,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"gpt_bigcode\",\n        \"n_embd\": 2048,\n        \"n_head\": 16,\n        \"n_inner\": 8192,\n        \"n_layer\": 24,\n        \"n_positions\": 2048,\n        \"resid_pdrop\": 0.1,\n        \"runner_max_sequence_length\": None,\n        \"scale_attention_softmax_in_fp32\": True,\n        \"scale_attn_weights\": True,\n        \"summary_activation\": None,\n        \"summary_first_dropout\": 0.1,\n        \"summary_proj_to_labels\": True,\n        \"summary_type\": \"cls_index\",\n        
\"summary_use_proj\": True,\n        \"transformers_version\": \"4.28.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 49280,\n    },\n    \"Mixtral-8x7B-v0.1\": {\n        \"architectures\": [\"MixtralForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 32768,\n        \"model_type\": \"mixtral\",\n        \"num_attention_heads\": 32,\n        \"num_experts_per_tok\": 2,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"num_local_experts\": 8,\n        \"output_router_logits\": False,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_theta\": 1000000.0,\n        \"router_aux_loss_coef\": 0.02,\n        \"sliding_window\": None,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.36.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 32000,\n    },\n    \"redpajama_3b_v1\": {\n        \"_name_or_path\": \"/root/fm/models/rp_3b_800b_real_fp16\",\n        \"architectures\": [\"GPTNeoXForCausalLM\"],\n        \"bos_token_id\": 0,\n        \"eos_token_id\": 0,\n        \"hidden_act\": \"gelu\",\n        \"hidden_size\": 2560,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 10240,\n        \"layer_norm_eps\": 1e-05,\n        \"max_position_embeddings\": 2048,\n        \"model_type\": \"gpt_neox\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"rotary_emb_base\": 10000,\n        \"rotary_pct\": 1.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.28.1\",\n        \"use_cache\": True,\n        \"use_parallel_residual\": False,\n        \"vocab_size\": 50432,\n    },\n    
\"phi-1_5\": {\n        \"_name_or_path\": \"microsoft/phi-1_5\",\n        \"activation_function\": \"gelu_new\",\n        \"architectures\": [\"PhiForCausalLM\"],\n        \"attn_pdrop\": 0.0,\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_phi.PhiConfig\",\n            \"AutoModelForCausalLM\": \"modeling_phi.PhiForCausalLM\",\n        },\n        \"embd_pdrop\": 0.0,\n        \"flash_attn\": False,\n        \"flash_rotary\": False,\n        \"fused_dense\": False,\n        \"initializer_range\": 0.02,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"phi-msft\",\n        \"n_embd\": 2048,\n        \"n_head\": 32,\n        \"n_head_kv\": None,\n        \"n_inner\": None,\n        \"n_layer\": 24,\n        \"n_positions\": 2048,\n        \"resid_pdrop\": 0.0,\n        \"rotary_dim\": 32,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.34.1\",\n        \"vocab_size\": 51200,\n    },\n    \"phi-2\": {\n        \"_name_or_path\": \"microsoft/phi-2\",\n        \"activation_function\": \"gelu_new\",\n        \"architectures\": [\"PhiForCausalLM\"],\n        \"attn_pdrop\": 0.0,\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_phi.PhiConfig\",\n            \"AutoModelForCausalLM\": \"modeling_phi.PhiForCausalLM\",\n        },\n        \"embd_pdrop\": 0.0,\n        \"flash_attn\": False,\n        \"flash_rotary\": False,\n        \"fused_dense\": False,\n        \"img_processor\": None,\n        \"initializer_range\": 0.02,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"phi-msft\",\n        \"n_embd\": 2560,\n        \"n_head\": 32,\n        \"n_head_kv\": None,\n        \"n_inner\": None,\n        \"n_layer\": 32,\n        \"n_positions\": 2048,\n        \"resid_pdrop\": 0.1,\n        \"rotary_dim\": 32,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": 
\"4.35.2\",\n        \"vocab_size\": 51200,\n    },\n    \"phi-3_5\": {\n        \"_name_or_path\": \"Phi-3.5-mini-instruct\",\n        \"architectures\": [\"Phi3ForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_phi3.Phi3Config\",\n            \"AutoModelForCausalLM\": \"modeling_phi3.Phi3ForCausalLM\",\n        },\n        \"bos_token_id\": 1,\n        \"embd_pdrop\": 0.0,\n        \"eos_token_id\": 32000,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 3072,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 8192,\n        \"max_position_embeddings\": 131072,\n        \"model_type\": \"phi3\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 32,\n        \"original_max_position_embeddings\": 4096,\n        \"pad_token_id\": 32000,\n        \"resid_pdrop\": 0.0,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"long_factor\": [\n                1.0800000429153442,\n                1.1100000143051147,\n                1.1399999856948853,\n                1.340000033378601,\n                1.5899999141693115,\n                1.600000023841858,\n                1.6200000047683716,\n                2.620000123977661,\n                3.2300000190734863,\n                3.2300000190734863,\n                4.789999961853027,\n                7.400000095367432,\n                7.700000286102295,\n                9.09000015258789,\n                12.199999809265137,\n                17.670000076293945,\n                24.46000099182129,\n                28.57000160217285,\n                30.420001983642578,\n                30.840002059936523,\n                32.590003967285156,\n                32.93000411987305,\n                42.320003509521484,\n                44.96000289916992,\n                50.340003967285156,\n                50.45000457763672,\n                
57.55000305175781,\n                57.93000411987305,\n                58.21000289916992,\n                60.1400032043457,\n                62.61000442504883,\n                62.62000274658203,\n                62.71000289916992,\n                63.1400032043457,\n                63.1400032043457,\n                63.77000427246094,\n                63.93000411987305,\n                63.96000289916992,\n                63.970001220703125,\n                64.02999877929688,\n                64.06999969482422,\n                64.08000183105469,\n                64.12000274658203,\n                64.41000366210938,\n                64.4800033569336,\n                64.51000213623047,\n                64.52999877929688,\n                64.83999633789062,\n            ],\n            \"short_factor\": [\n                1.0,\n                1.0199999809265137,\n                1.0299999713897705,\n                1.0299999713897705,\n                1.0499999523162842,\n                1.0499999523162842,\n                1.0499999523162842,\n                1.0499999523162842,\n                1.0499999523162842,\n                1.0699999332427979,\n                1.0999999046325684,\n                1.1099998950958252,\n                1.1599998474121094,\n                1.1599998474121094,\n                1.1699998378753662,\n                1.2899998426437378,\n                1.339999794960022,\n                1.679999828338623,\n                1.7899998426437378,\n                1.8199998140335083,\n                1.8499997854232788,\n                1.8799997568130493,\n                1.9099997282028198,\n                1.9399996995925903,\n                1.9899996519088745,\n                2.0199997425079346,\n                2.0199997425079346,\n                2.0199997425079346,\n                2.0199997425079346,\n                2.0199997425079346,\n                2.0199997425079346,\n                2.0299997329711914,\n           
     2.0299997329711914,\n                2.0299997329711914,\n                2.0299997329711914,\n                2.0299997329711914,\n                2.0299997329711914,\n                2.0299997329711914,\n                2.0299997329711914,\n                2.0299997329711914,\n                2.0799996852874756,\n                2.0899996757507324,\n                2.189999580383301,\n                2.2199995517730713,\n                2.5899994373321533,\n                2.729999542236328,\n                2.749999523162842,\n                2.8399994373321533,\n            ],\n            \"type\": \"longrope\",\n        },\n        \"rope_theta\": 10000.0,\n        \"sliding_window\": 262144,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.43.3\",\n        \"use_cache\": True,\n        \"attention_bias\": False,\n        \"vocab_size\": 32064,\n    },\n    \"phi-3_5-vision\": {\n        \"_name_or_path\": \"Phi-3.5-vision-instruct\",\n        \"architectures\": [\"Phi3VForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_phi3_v.Phi3VConfig\",\n            \"AutoModelForCausalLM\": \"modeling_phi3_v.Phi3VForCausalLM\",\n        },\n        \"bos_token_id\": 1,\n        \"embd_layer\": {\n            \"embedding_cls\": \"image\",\n            \"hd_transform_order\": \"sub_glb\",\n            \"projection_cls\": \"mlp\",\n            \"use_hd_transform\": True,\n            \"with_learnable_separator\": True,\n        },\n        \"embd_pdrop\": 0.0,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 3072,\n        \"img_processor\": {\n            \"image_dim_out\": 1024,\n            \"model_name\": \"openai/clip-vit-large-patch14-336\",\n            \"name\": \"clip_vision_model\",\n            \"num_img_tokens\": 144,\n        },\n        \"initializer_range\": 0.02,\n       
 \"intermediate_size\": 8192,\n        \"max_position_embeddings\": 131072,\n        \"model_type\": \"phi3_v\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 32,\n        \"original_max_position_embeddings\": 4096,\n        \"pad_token_id\": 32000,\n        \"resid_pdrop\": 0.0,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"long_factor\": [\n                1.0800000429153442,\n                1.1100000143051147,\n                1.1399999856948853,\n                1.340000033378601,\n                1.5899999141693115,\n                1.600000023841858,\n                1.6200000047683716,\n                2.620000123977661,\n                3.2300000190734863,\n                3.2300000190734863,\n                4.789999961853027,\n                7.400000095367432,\n                7.700000286102295,\n                9.09000015258789,\n                12.199999809265137,\n                17.670000076293945,\n                24.46000099182129,\n                28.57000160217285,\n                30.420001983642578,\n                30.840002059936523,\n                32.590003967285156,\n                32.93000411987305,\n                42.320003509521484,\n                44.96000289916992,\n                50.340003967285156,\n                50.45000457763672,\n                57.55000305175781,\n                57.93000411987305,\n                58.21000289916992,\n                60.1400032043457,\n                62.61000442504883,\n                62.62000274658203,\n                62.71000289916992,\n                63.1400032043457,\n                63.1400032043457,\n                63.77000427246094,\n                63.93000411987305,\n                63.96000289916992,\n                63.970001220703125,\n                64.02999877929688,\n                64.06999969482422,\n                64.08000183105469,\n                64.12000274658203,\n  
              64.41000366210938,\n                64.4800033569336,\n                64.51000213623047,\n                64.52999877929688,\n                64.83999633789062,\n            ],\n            \"short_factor\": [\n                1.08,\n                1.1,\n                1.1300000000000001,\n                1.2800000000000002,\n                1.3100000000000003,\n                1.4500000000000004,\n                1.4500000000000004,\n                1.9500000000000008,\n                2.030000000000001,\n                2.4299999999999926,\n                2.5699999999999896,\n                2.9499999999999815,\n                3.729999999999965,\n                3.869999999999962,\n                4.189999999999955,\n                4.43999999999995,\n                4.6399999999999455,\n                4.979999999999938,\n                5.159999999999934,\n                5.279999999999932,\n                5.759999999999922,\n                5.889999999999919,\n                5.889999999999919,\n                5.969999999999917,\n                6.089999999999915,\n                6.2799999999999105,\n                6.7699999999999,\n                6.8899999999998975,\n                7.109999999999893,\n                7.129999999999892,\n                7.179999999999891,\n                7.289999999999889,\n                7.339999999999888,\n                7.559999999999883,\n                7.619999999999882,\n                7.69999999999988,\n                7.879999999999876,\n                7.879999999999876,\n                7.879999999999876,\n                7.939999999999875,\n                7.949999999999875,\n                7.979999999999874,\n                8.19999999999987,\n                8.439999999999864,\n                8.469999999999864,\n                8.589999999999861,\n                8.809999999999857,\n                8.999999999999853,\n            ],\n            \"type\": \"su\",\n        },\n       
 \"rope_theta\": 10000.0,\n        \"sliding_window\": 262144,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.38.1\",\n        \"use_cache\": True,\n        \"vocab_size\": 32064,\n        \"_attn_implementation\": \"flash_attention_2\",\n    },\n    \"phi-4\": {\n        \"_name_or_path\": \"Phi-4-mini-instruct\",\n        \"architectures\": [\"Phi3ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_phi3.Phi3Config\",\n            \"AutoModelForCausalLM\": \"modeling_phi3.Phi3ForCausalLM\",\n            \"AutoTokenizer\": \"Xenova/gpt-4o\",\n        },\n        \"bos_token_id\": 199999,\n        \"embd_pdrop\": 0.0,\n        \"eos_token_id\": 199999,\n        \"full_attn_mod\": 1,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 3072,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 8192,\n        \"interpolate_factor\": 1,\n        \"lm_head_bias\": False,\n        \"max_position_embeddings\": 131072,\n        \"mlp_bias\": False,\n        \"model_type\": \"phi3\",\n        \"num_attention_heads\": 24,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"original_max_position_embeddings\": 4096,\n        \"pad_token_id\": 199999,\n        \"partial_rotary_factor\": 0.75,\n        \"resid_pdrop\": 0.0,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"long_factor\": [\n                1,\n                1.118320672,\n                1.250641126,\n                1.398617824,\n                1.564103225,\n                1.74916897,\n                1.956131817,\n                2.187582649,\n                2.446418898,\n                2.735880826,\n                3.059592084,\n                3.421605075,\n                3.826451687,\n                4.279200023,\n                
4.785517845,\n                5.351743533,\n                5.984965424,\n                6.693110555,\n                7.485043894,\n                8.370679318,\n                9.36110372,\n                10.4687158,\n                11.70738129,\n                13.09260651,\n                14.64173252,\n                16.37415215,\n                18.31155283,\n                20.47818807,\n                22.90118105,\n                25.61086418,\n                28.64115884,\n                32.03,\n                32.1,\n                32.13,\n                32.23,\n                32.6,\n                32.61,\n                32.64,\n                32.66,\n                32.7,\n                32.71,\n                32.93,\n                32.97,\n                33.28,\n                33.49,\n                33.5,\n                44.16,\n                47.77,\n            ],\n            \"short_factor\": [\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                1.0,\n                
1.0,\n            ],\n            \"type\": \"longrope\",\n        },\n        \"rope_theta\": 10000.0,\n        \"sliding_window\": 262144,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.45.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 200064,\n    },\n    \"qwen\": {\n        \"architectures\": [\"QWenLMHeadModel\"],\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_qwen.QWenConfig\",\n            \"AutoModelForCausalLM\": \"modeling_qwen.QWenLMHeadModel\",\n        },\n        \"attn_dropout_prob\": 0.0,\n        \"bf16\": False,\n        \"emb_dropout_prob\": 0.0,\n        \"hidden_size\": 2048,\n        \"intermediate_size\": 11008,\n        \"initializer_range\": 0.02,\n        \"kv_channels\": 128,\n        \"layer_norm_epsilon\": 1e-06,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"qwen\",\n        \"no_bias\": True,\n        \"num_attention_heads\": 16,\n        \"num_hidden_layers\": 24,\n        \"rotary_emb_base\": 10000,\n        \"rotary_pct\": 1.0,\n        \"scale_attn_weights\": True,\n        \"seq_length\": 8192,\n        \"tie_word_embeddings\": False,\n        \"tokenizer_class\": \"QWenTokenizer\",\n        \"transformers_version\": \"4.32.0\",\n        \"use_cache\": True,\n        \"use_dynamic_ntk\": True,\n        \"use_flash_attn\": \"auto\",\n        \"use_logn_attn\": True,\n        \"vocab_size\": 151936,\n    },\n    \"qwen2\": {\n        \"_name_or_path\": \"Qwen/Qwen1.5-1.8B-Chat\",\n        \"architectures\": [\"Qwen2ForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 5504,\n        \"max_position_embeddings\": 4096,\n        \"max_window_layers\": 21,\n        \"model_type\": \"qwen2\",\n        
\"num_attention_heads\": 16,\n        \"num_hidden_layers\": 24,\n        \"num_key_value_heads\": 16,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": 32768,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.37.2\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n    },\n    \"qwen2moe\": {\n        \"architectures\": [\"Qwen2MoeForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 5632,\n        \"max_position_embeddings\": 32768,\n        \"max_window_layers\": 21,\n        \"model_type\": \"qwen2_moe\",\n        \"num_attention_heads\": 16,\n        \"num_hidden_layers\": 24,\n        \"num_key_value_heads\": 16,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": 32768,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.39.0.dev0\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n        \"decoder_sparse_step\": 1,\n        \"moe_intermediate_size\": 1408,\n        \"shared_expert_intermediate_size\": 5632,\n        \"num_experts_per_tok\": 4,\n        \"num_experts\": 60,\n        \"norm_topk_prob\": False,\n        \"output_router_logits\": False,\n        \"router_aux_loss_coef\": 0.001,\n    },\n    \"deepseek_v2_lite\": {\n        \"architectures\": [\"DeepseekV2ForCausalLM\"],\n        \"attention_bias\": False,\n        \"bos_token_id\": 100000,\n        \"eos_token_id\": 100001,\n        \"first_k_dense_replace\": 1,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        
\"initializer_range\": 0.02,\n        \"intermediate_size\": 10944,\n        \"kv_lora_rank\": 512,\n        \"max_position_embeddings\": 163840,\n        \"model_type\": \"deepseek_v2\",\n        \"moe_intermediate_size\": 1408,\n        \"moe_layer_freq\": 1,\n        \"n_group\": 1,\n        \"n_routed_experts\": 64,\n        \"n_shared_experts\": 2,\n        \"norm_topk_prob\": False,\n        \"num_attention_heads\": 16,\n        \"num_experts_per_tok\": 6,\n        \"num_hidden_layers\": 27,\n        \"num_key_value_heads\": 16,\n        \"pretraining_tp\": 1,\n        \"qk_nope_head_dim\": 128,\n        \"qk_rope_head_dim\": 64,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_scaling\": {\n            \"beta_fast\": 32,\n            \"beta_slow\": 1,\n            \"factor\": 40,\n            \"mscale\": 0.707,\n            \"mscale_all_dim\": 0.707,\n            \"original_max_position_embeddings\": 4096,\n            \"type\": \"yarn\",\n        },\n        \"rope_theta\": 10000,\n        \"routed_scaling_factor\": 1.0,\n        \"scoring_func\": \"softmax\",\n        \"topk_group\": 1,\n        \"topk_method\": \"greedy\",\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.33.1\",\n        \"use_cache\": True,\n        \"v_head_dim\": 128,\n        \"vocab_size\": 102400,\n    },\n    \"stablelm\": {\n        \"architectures\": [\"StableLmForCausalLM\"],\n        \"bos_token_id\": 0,\n        \"eos_token_id\": 0,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2560,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 6912,\n        \"max_position_embeddings\": 4096,\n        \"model_type\": \"stablelm\",\n        \"layer_norm_eps\": 1e-05,\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 32,\n        \"partial_rotary_factor\": 0.25,\n        \"rope_theta\": 10000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": 
\"bfloat16\",\n        \"transformers_version\": \"4.38.0\",\n        \"use_cache\": True,\n        \"use_qkv_bias\": False,\n        \"vocab_size\": 50304,\n    },\n    \"baichuan\": {\n        \"architectures\": [\"BaichuanForCausalLM\"],\n        \"tokenizer_class\": \"BaichuanTokenizer\",\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 11008,\n        \"max_position_embeddings\": 4096,\n        \"model_max_length\": 4096,\n        \"model_type\": \"baichuan\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"pad_token_id\": 0,\n        \"rms_norm_eps\": 1e-06,\n        \"_from_model_config\": True,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.29.2\",\n        \"use_cache\": True,\n        \"vocab_size\": 125696,\n    },\n    \"internlm\": {\n        \"architectures\": [\"InternLMForCausalLM\"],\n        \"bias\": True,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 11008,\n        \"max_position_embeddings\": 2048,\n        \"model_type\": \"internlm\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"pad_token_id\": 2,\n        \"rms_norm_eps\": 1e-06,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.33.2\",\n        \"use_cache\": True,\n        \"vocab_size\": 103168,\n    },\n    \"gemma_2b\": {\n        \"architectures\": [\"GemmaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 2,\n        \"eos_token_id\": 1,\n        \"head_dim\": 256,\n        \"hidden_act\": \"gelu\",\n        \"hidden_size\": 2048,\n        
\"initializer_range\": 0.02,\n        \"intermediate_size\": 16384,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"gemma\",\n        \"num_attention_heads\": 8,\n        \"num_hidden_layers\": 18,\n        \"num_key_value_heads\": 1,\n        \"pad_token_id\": 0,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000.0,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.38.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 256000,\n    },\n    \"gemma2_2b\": {\n        \"architectures\": [\"Gemma2ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"attn_logit_softcapping\": 50.0,\n        \"bos_token_id\": 2,\n        \"cache_implementation\": \"hybrid\",\n        \"eos_token_id\": [1, 107],\n        \"final_logit_softcapping\": 30.0,\n        \"head_dim\": 256,\n        \"hidden_act\": \"gelu_pytorch_tanh\",\n        \"hidden_activation\": \"gelu_pytorch_tanh\",\n        \"hidden_size\": 2304,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 9216,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"gemma2\",\n        \"num_attention_heads\": 8,\n        \"num_hidden_layers\": 26,\n        \"num_key_value_heads\": 4,\n        \"pad_token_id\": 0,\n        \"query_pre_attn_scalar\": 256,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 10000.0,\n        \"sliding_window\": 4096,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.4\",\n        \"use_cache\": True,\n        \"vocab_size\": 256000,\n    },\n    \"gemma2_2b-jpn\": {\n        \"architectures\": [\"Gemma2ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"attn_logit_softcapping\": 50.0,\n        \"bos_token_id\": 2,\n        \"cache_implementation\": \"hybrid\",\n        \"dtype\": \"bfloat16\",\n        \"eos_token_id\": 
1,\n        \"final_logit_softcapping\": 30.0,\n        \"head_dim\": 256,\n        \"hidden_activation\": \"gelu_pytorch_tanh\",\n        \"hidden_size\": 2304,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 9216,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"gemma2\",\n        \"num_attention_heads\": 8,\n        \"num_hidden_layers\": 26,\n        \"num_key_value_heads\": 4,\n        \"pad_token_id\": 0,\n        \"query_pre_attn_scalar\": 224,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 10000.0,\n        \"sliding_window\": 4096,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.44.2\",\n        \"use_cache\": True,\n        \"vocab_size\": 256000,\n    },\n    \"gemma2_9b\": {\n        \"architectures\": [\"Gemma2ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"attn_logit_softcapping\": 50.0,\n        \"bos_token_id\": 2,\n        \"cache_implementation\": \"hybrid\",\n        \"eos_token_id\": 1,\n        \"final_logit_softcapping\": 30.0,\n        \"head_dim\": 256,\n        \"hidden_act\": \"gelu_pytorch_tanh\",\n        \"hidden_activation\": \"gelu_pytorch_tanh\",\n        \"hidden_size\": 3584,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"gemma2\",\n        \"num_attention_heads\": 16,\n        \"num_hidden_layers\": 42,\n        \"num_key_value_heads\": 8,\n        \"pad_token_id\": 0,\n        \"query_pre_attn_scalar\": 256,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 10000.0,\n        \"sliding_window\": 4096,\n        \"sliding_window_size\": 4096,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 256000,\n    },\n    \"gemma3_1b_it\": {\n        \"architectures\": [\"Gemma3ForCausalLM\"],\n        
\"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"attn_logit_softcapping\": None,\n        \"bos_token_id\": 2,\n        \"cache_implementation\": \"hybrid\",\n        \"eos_token_id\": [1, 106],\n        \"final_logit_softcapping\": None,\n        \"head_dim\": 256,\n        \"hidden_activation\": \"gelu_pytorch_tanh\",\n        \"hidden_size\": 1152,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 6912,\n        \"max_position_embeddings\": 32768,\n        \"model_type\": \"gemma3_text\",\n        \"num_attention_heads\": 4,\n        \"num_hidden_layers\": 26,\n        \"num_key_value_heads\": 1,\n        \"pad_token_id\": 0,\n        \"query_pre_attn_scalar\": 256,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_local_base_freq\": 10000,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"sliding_window\": 512,\n        \"sliding_window_pattern\": 6,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.50.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 262144,\n    },\n    \"gemma2_27b\": {\n        \"architectures\": [\"Gemma2ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"attn_logit_softcapping\": 50.0,\n        \"bos_token_id\": 2,\n        \"cache_implementation\": \"hybrid\",\n        \"eos_token_id\": 1,\n        \"final_logit_softcapping\": 30.0,\n        \"head_dim\": 128,\n        \"hidden_act\": \"gelu_pytorch_tanh\",\n        \"hidden_activation\": \"gelu_pytorch_tanh\",\n        \"hidden_size\": 4608,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 36864,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"gemma2\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 46,\n        \"num_key_value_heads\": 16,\n        \"pad_token_id\": 0,\n        \"query_pre_attn_scalar\": 144,\n        \"rms_norm_eps\": 1e-06,\n        
\"rope_theta\": 10000.0,\n        \"sliding_window\": 4096,\n        \"sliding_window_size\": 4096,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 256000,\n        \"_attn_implementation\": \"eager\",\n    },\n    \"rwkv5_3b\": {\n        \"architectures\": [\"RwkvForCausalLM\"],\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_rwkv5.Rwkv5Config\",\n            \"AutoModelForCausalLM\": \"modeling_rwkv5.RwkvForCausalLM\",\n        },\n        \"attention_hidden_size\": 2560,\n        \"bos_token_id\": 0,\n        \"context_length\": 4096,\n        \"eos_token_id\": 0,\n        \"head_size\": 64,\n        \"hidden_size\": 2560,\n        \"intermediate_size\": None,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"rwkv5\",\n        \"model_version\": \"5_2\",\n        \"num_hidden_layers\": 32,\n        \"rescale_every\": 6,\n        \"tie_word_embeddings\": True,\n        \"transformers_version\": \"4.34.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 65536,\n    },\n    \"orion\": {\n        \"architectures\": [\"OrionForCausalLM\"],\n        \"auto_map\": {\n            \"AutoConfig\": \"configuration_orion.OrionConfig\",\n            \"AutoModelForCausalLM\": \"modeling_orion.OrionForCausalLM\",\n        },\n        \"tokenizer_class\": \"OrionTokenizer\",\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 5120,\n        \"model_type\": \"orion\",\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 15360,\n        \"max_position_embeddings\": 4096,\n        \"max_sequence_length\": 4096,\n        \"num_attention_heads\": 40,\n        \"num_hidden_layers\": 40,\n        \"num_key_value_heads\": 40,\n        \"pad_token_id\": 0,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n    
    \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.34.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 84608,\n    },\n    \"llava\": {\n        \"architectures\": [\"LlavaForConditionalGeneration\"],\n        \"ignore_index\": -100,\n        \"image_token_index\": 32000,\n        \"model_type\": \"llava\",\n        \"pad_token_id\": 32001,\n        \"projector_hidden_act\": \"gelu\",\n        \"text_config\": {\n            \"_name_or_path\": \"meta-llama/Llama-2-7b-hf\",\n            \"architectures\": [\"LlamaForCausalLM\"],\n            \"max_position_embeddings\": 4096,\n            \"model_type\": \"llama\",\n            \"rms_norm_eps\": 1e-05,\n            \"torch_dtype\": \"float16\",\n            \"vocab_size\": 32064,\n        },\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.36.0.dev0\",\n        \"vision_config\": {\n            \"hidden_size\": 1024,\n            \"image_size\": 336,\n            \"intermediate_size\": 4096,\n            \"model_type\": \"clip_vision_model\",\n            \"num_attention_heads\": 16,\n            \"num_hidden_layers\": 24,\n            \"patch_size\": 14,\n            \"projection_dim\": 768,\n            \"vocab_size\": 32000,\n        },\n        \"vision_feature_layer\": -2,\n        \"vision_feature_select_strategy\": \"default\",\n        \"vocab_size\": 32064,\n    },\n    \"chatglm\": {\n        \"architectures\": [\"ChatGLMModel\"],\n        \"model_type\": \"chatglm\",\n        \"add_bias_linear\": False,\n        \"add_qkv_bias\": True,\n        \"apply_query_key_layer_scaling\": True,\n        \"apply_residual_connection_post_layernorm\": False,\n        \"attention_dropout\": 0.0,\n        \"attention_softmax_in_fp32\": True,\n        \"bias_dropout_fusion\": True,\n        \"ffn_hidden_size\": 13696,\n        
\"fp32_residual_connection\": False,\n        \"hidden_dropout\": 0.0,\n        \"hidden_size\": 4096,\n        \"kv_channels\": 128,\n        \"layernorm_epsilon\": 1e-05,\n        \"multi_query_attention\": True,\n        \"multi_query_group_num\": 2,\n        \"num_attention_heads\": 32,\n        \"num_layers\": 28,\n        \"original_rope\": True,\n        \"padded_vocab_size\": 65024,\n        \"post_layer_norm\": True,\n        \"rmsnorm\": True,\n        \"seq_length\": 8192,\n        \"use_cache\": True,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.30.2\",\n        \"tie_word_embeddings\": False,\n        \"eos_token_id\": 2,\n        \"pad_token_id\": 0,\n    },\n    \"llama3_1_8b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 128000,\n        \"eos_token_id\": [128001, 128008, 128009],\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 131072,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"factor\": 8.0,\n            \"low_freq_factor\": 1.0,\n            \"high_freq_factor\": 4.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n        \"rope_theta\": 500000.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"use_cache\": True,\n        \"vocab_size\": 128256,\n    },\n    \"llama3_1_70b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        
\"attention_dropout\": 0.0,\n        \"bos_token_id\": 128000,\n        \"eos_token_id\": [128001, 128008, 128009],\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 8192,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 28672,\n        \"max_position_embeddings\": 131072,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 64,\n        \"num_hidden_layers\": 80,\n        \"num_key_value_heads\": 8,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"factor\": 8.0,\n            \"low_freq_factor\": 1.0,\n            \"high_freq_factor\": 4.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n        \"rope_theta\": 500000.0,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"use_cache\": True,\n        \"vocab_size\": 128256,\n    },\n    \"llama3_2_1b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 128000,\n        \"eos_token_id\": [128001, 128008, 128009],\n        \"head_dim\": 64,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 8192,\n        \"max_position_embeddings\": 131072,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 16,\n        \"num_key_value_heads\": 8,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"factor\": 32.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n        \"rope_theta\": 
500000.0,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.45.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 128256,\n    },\n    \"llama3_2_3b\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 128000,\n        \"eos_token_id\": [128001, 128008, 128009],\n        \"head_dim\": 128,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 3072,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 8192,\n        \"max_position_embeddings\": 131072,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 24,\n        \"num_hidden_layers\": 28,\n        \"num_key_value_heads\": 8,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\n            \"factor\": 32.0,\n            \"high_freq_factor\": 4.0,\n            \"low_freq_factor\": 1.0,\n            \"original_max_position_embeddings\": 8192,\n            \"rope_type\": \"llama3\",\n        },\n        \"rope_theta\": 500000.0,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.45.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 128256,\n    },\n    \"snowflake-arctic-embed-m\": {\n        \"architectures\": [\"BertModel\"],\n        \"attention_probs_dropout_prob\": 0.1,\n        \"classifier_dropout\": None,\n        \"gradient_checkpointing\": False,\n        \"hidden_act\": \"gelu\",\n        \"hidden_dropout_prob\": 0.1,\n        \"hidden_size\": 768,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 3072,\n        \"layer_norm_eps\": 1e-12,\n        \"max_position_embeddings\": 512,\n        \"model_type\": \"bert\",\n        \"num_attention_heads\": 12,\n        \"num_hidden_layers\": 12,\n        
\"pad_token_id\": 0,\n        \"position_embedding_type\": \"absolute\",\n        \"torch_dtype\": \"float32\",\n        \"transformers_version\": \"4.36.1\",\n        \"type_vocab_size\": 2,\n        \"use_cache\": True,\n        \"vocab_size\": 30522,\n    },\n    # \"snowflake-arctic-embed-s\": {\n    #     \"architectures\": [\"BertModel\"],\n    #     \"attention_probs_dropout_prob\": 0.1,\n    #     \"classifier_dropout\": None,\n    #     \"hidden_act\": \"gelu\",\n    #     \"hidden_dropout_prob\": 0.1,\n    #     \"hidden_size\": 384,\n    #     \"initializer_range\": 0.02,\n    #     \"intermediate_size\": 1536,\n    #     \"layer_norm_eps\": 1e-12,\n    #     \"max_position_embeddings\": 512,\n    #     \"model_type\": \"bert\",\n    #     \"num_attention_heads\": 12,\n    #     \"num_hidden_layers\": 12,\n    #     \"pad_token_id\": 0,\n    #     \"position_embedding_type\": \"absolute\",\n    #     \"torch_dtype\": \"float32\",\n    #     \"transformers_version\": \"4.36.1\",\n    #     \"type_vocab_size\": 2,\n    #     \"use_cache\": True,\n    #     \"vocab_size\": 30522,\n    # },\n    \"stablelm-2-zephyr-1_6b\": {\n        \"architectures\": [\"StableLmForCausalLM\"],\n        \"bos_token_id\": 100257,\n        \"eos_token_id\": 100257,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 5632,\n        \"max_position_embeddings\": 4096,\n        \"model_type\": \"stablelm\",\n        \"layer_norm_eps\": 1e-05,\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 24,\n        \"num_key_value_heads\": 32,\n        \"partial_rotary_factor\": 0.25,\n        \"rope_theta\": 10000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.38.0\",\n        \"use_cache\": True,\n        \"use_qkv_bias\": True,\n        \"vocab_size\": 100352,\n    },\n    \"qwen2_0_5b\": {\n        
\"architectures\": [\"Qwen2ForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 896,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 4864,\n        \"max_position_embeddings\": 32768,\n        \"max_window_layers\": 24,\n        \"model_type\": \"qwen2\",\n        \"num_attention_heads\": 14,\n        \"num_hidden_layers\": 24,\n        \"num_key_value_heads\": 2,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": 32768,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.40.1\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n    },\n    \"qwen2_1_5b\": {\n        \"architectures\": [\"Qwen2ForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 1536,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 8960,\n        \"max_position_embeddings\": 32768,\n        \"max_window_layers\": 28,\n        \"model_type\": \"qwen2\",\n        \"num_attention_heads\": 12,\n        \"num_hidden_layers\": 28,\n        \"num_key_value_heads\": 2,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": 32768,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.40.1\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n    },\n    \"qwen2.5_3b\": {\n        \"architectures\": [\"Qwen2ForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"hidden_act\": \"silu\",\n        
\"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 11008,\n        \"max_position_embeddings\": 32768,\n        \"max_window_layers\": 70,\n        \"model_type\": \"qwen2\",\n        \"num_attention_heads\": 16,\n        \"num_hidden_layers\": 36,\n        \"num_key_value_heads\": 2,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": 32768,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.43.1\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n    },\n    \"qwen2_7b\": {\n        \"architectures\": [\"Qwen2ForCausalLM\"],\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 3584,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 18944,\n        \"max_position_embeddings\": 32768,\n        \"max_window_layers\": 28,\n        \"model_type\": \"qwen2\",\n        \"num_attention_heads\": 28,\n        \"num_hidden_layers\": 28,\n        \"num_key_value_heads\": 4,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_theta\": 1000000.0,\n        \"sliding_window\": 131072,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.41.2\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 152064,\n    },\n    \"qwen3_0.6b\": {\n        \"architectures\": [\"Qwen3ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"head_dim\": 128,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 1024,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 3072,\n        \"max_position_embeddings\": 
40960,\n        \"max_window_layers\": 28,\n        \"model_type\": \"qwen3\",\n        \"num_attention_heads\": 16,\n        \"num_hidden_layers\": 28,\n        \"num_key_value_heads\": 8,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"sliding_window\": None,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.51.0\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n    },\n    \"qwen3_1.7b\": {\n        \"architectures\": [\"Qwen3ForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 151643,\n        \"eos_token_id\": 151645,\n        \"head_dim\": 128,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 6144,\n        \"max_position_embeddings\": 40960,\n        \"max_window_layers\": 28,\n        \"model_type\": \"qwen3\",\n        \"num_attention_heads\": 16,\n        \"num_hidden_layers\": 28,\n        \"num_key_value_heads\": 8,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"sliding_window\": None,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.51.0\",\n        \"use_cache\": True,\n        \"use_sliding_window\": False,\n        \"vocab_size\": 151936,\n    },\n    \"internlm2\": {\n        \"architectures\": [\"InternLM2ForCausalLM\"],\n        \"attn_implementation\": \"eager\",\n        \"bias\": False,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 32768,\n        \"model_type\": \"internlm2\",\n      
  \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"pad_token_id\": 2,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 1000000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.37.1\",\n        \"use_cache\": True,\n        \"vocab_size\": 92544,\n    },\n    \"internlm2_5_7b\": {\n        \"architectures\": [\"InternLM2ForCausalLM\"],\n        \"attn_implementation\": \"eager\",\n        \"bias\": False,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"max_position_embeddings\": 32768,\n        \"model_type\": \"internlm2\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 8,\n        \"pad_token_id\": 2,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\"type\": \"dynamic\", \"factor\": 2.0},\n        \"rope_theta\": 1000000,\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.41.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 92544,\n        \"pretraining_tp\": 1,\n    },\n    \"starcoder2\": {\n        \"activation_function\": \"gelu\",\n        \"architectures\": [\"Starcoder2ForCausalLM\"],\n        \"attention_dropout\": 0.1,\n        \"residual_dropout\": 0.1,\n        \"embedding_dropout\": 0.1,\n        \"attention_softmax_in_fp32\": True,\n        \"bos_token_id\": 0,\n        \"eos_token_id\": 0,\n        \"hidden_act\": \"gelu_pytorch_tanh\",\n        \"hidden_size\": 4608,\n        \"initializer_range\": 0.018042,\n        \"intermediate_size\": 18432,\n        \"layer_norm_epsilon\": 1e-05,\n        \"max_position_embeddings\": 16384,\n        \"mlp_type\": 
\"default\",\n        \"model_type\": \"starcoder2\",\n        \"norm_epsilon\": 1e-05,\n        \"norm_type\": \"layer_norm\",\n        \"num_attention_heads\": 36,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 4,\n        \"rope_theta\": 1000000,\n        \"scale_attention_softmax_in_fp32\": True,\n        \"scale_attn_weights\": True,\n        \"sliding_window\": 4096,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.37.0.dev0\",\n        \"use_bias\": True,\n        \"use_cache\": True,\n        \"vocab_size\": 49152,\n    },\n    \"smollm_1_7b\": {\n        \"_name_or_path\": \"HuggingFaceTB/cosmo2-1.7B-webinst-sc2\",\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 8192,\n        \"max_position_embeddings\": 2048,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 24,\n        \"num_key_value_heads\": 32,\n        \"pad_token_id\": 2,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"use_cache\": True,\n        \"vocab_size\": 49152,\n    },\n    \"smollm_360m\": {\n        \"_name_or_path\": \"HuggingFaceTB/cosmo2-350M-webinst-sc2\",\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 960,\n        \"initializer_range\": 0.02,\n        
\"intermediate_size\": 2560,\n        \"max_position_embeddings\": 2048,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 15,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 5,\n        \"pad_token_id\": 2,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"use_cache\": True,\n        \"vocab_size\": 49152,\n    },\n    \"smollm_135m\": {\n        \"_name_or_path\": \"HuggingFaceTB/cosmo2-135M-webinst-sc2\",\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 576,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 1536,\n        \"max_position_embeddings\": 2048,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 9,\n        \"num_hidden_layers\": 30,\n        \"num_key_value_heads\": 3,\n        \"pad_token_id\": 2,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000.0,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"use_cache\": True,\n        \"vocab_size\": 49152,\n    },\n    \"smollm2_135m\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 576,\n        \"initializer_range\": 0.041666666666666664,\n        \"intermediate_size\": 1536,\n        
\"is_llama_config\": True,\n        \"max_position_embeddings\": 8192,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 9,\n        \"num_hidden_layers\": 30,\n        \"num_key_value_heads\": 3,\n        \"pad_token_id\": 2,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_interleaved\": False,\n        \"rope_scaling\": None,\n        \"rope_theta\": 100000,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"transformers.js_config\": {\n            \"kv_cache_dtype\": {\n                \"q4f16\": \"float16\",\n                \"fp16\": \"float16\",\n            }\n        },\n        \"use_cache\": True,\n        \"vocab_size\": 49152,\n    },\n    # \"smollm2_1_7b\": {\n    #     \"architectures\": [\"LlamaForCausalLM\"],\n    #     \"attention_bias\": False,\n    #     \"attention_dropout\": 0.0,\n    #     \"bos_token_id\": 1,\n    #     \"eos_token_id\": 2,\n    #     \"hidden_act\": \"silu\",\n    #     \"hidden_size\": 2048,\n    #     \"initializer_range\": 0.02,\n    #     \"intermediate_size\": 8192,\n    #     \"max_position_embeddings\": 8192,\n    #     \"mlp_bias\": False,\n    #     \"model_type\": \"llama\",\n    #     \"num_attention_heads\": 32,\n    #     \"num_hidden_layers\": 24,\n    #     \"num_key_value_heads\": 32,\n    #     \"pad_token_id\": 2,\n    #     \"pretraining_tp\": 1,\n    #     \"rms_norm_eps\": 1e-05,\n    #     \"rope_scaling\": None,\n    #     \"rope_theta\": 130000,\n    #     \"tie_word_embeddings\": True,\n    #     \"torch_dtype\": \"bfloat16\",\n    #     \"transformers_version\": \"4.42.3\",\n    #     \"transformers.js_config\": {\n    #         \"dtype\": \"q4\",\n    #         \"kv_cache_dtype\": {\n    #             \"q4f16\": \"float16\",\n    #             \"fp16\": \"float16\",\n    #         },\n    #         \"use_external_data_format\": 
{\n    #             \"model.onnx\": True,\n    #             \"model_fp16.onnx\": True,\n    #         },\n    #     },\n    #     \"use_cache\": True,\n    #     \"vocab_size\": 49152,\n    # },\n    \"smollm2_360m\": {\n        \"architectures\": [\"LlamaForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 960,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 2560,\n        \"is_llama_config\": True,\n        \"max_position_embeddings\": 8192,\n        \"mlp_bias\": False,\n        \"model_type\": \"llama\",\n        \"num_attention_heads\": 15,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 5,\n        \"pad_token_id\": 2,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_interleaved\": False,\n        \"rope_scaling\": None,\n        \"rope_theta\": 100000,\n        \"tie_word_embeddings\": True,\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.42.3\",\n        \"transformers.js_config\": {\n            \"kv_cache_dtype\": {\n                \"q4f16\": \"float16\",\n                \"fp16\": \"float16\",\n            }\n        },\n        \"use_cache\": True,\n        \"vocab_size\": 49152,\n    },\n    \"aya-23\": {\n        \"architectures\": [\"CohereForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 5,\n        \"eos_token_id\": 255001,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 4096,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 14336,\n        \"layer_norm_eps\": 1e-05,\n        \"logit_scale\": 0.0625,\n        \"max_position_embeddings\": 8192,\n        \"model_type\": \"cohere\",\n        \"num_attention_heads\": 32,\n        \"num_hidden_layers\": 32,\n        \"num_key_value_heads\": 
8,\n        \"pad_token_id\": 0,\n        \"rope_theta\": 10000,\n        \"torch_dtype\": \"float16\",\n        \"transformers_version\": \"4.40.0.dev0\",\n        \"use_cache\": True,\n        \"use_qk_norm\": False,\n        \"vocab_size\": 256000,\n    },\n    \"minicpm_2b\": {\n        \"architectures\": [\"MiniCPMForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2304,\n        \"initializer_range\": 0.1,\n        \"intermediate_size\": 5760,\n        \"max_position_embeddings\": 65536,\n        \"max_length\": 131072,\n        \"model_type\": \"minicpm\",\n        \"num_attention_heads\": 36,\n        \"num_hidden_layers\": 40,\n        \"num_key_value_heads\": 36,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": {\"type\": \"dynamic\", \"factor\": 4.0},\n        \"torch_dtype\": \"bfloat16\",\n        \"transformers_version\": \"4.36.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 122760,\n        \"scale_emb\": 12,\n        \"dim_model_base\": 256,\n        \"scale_depth\": 1.4,\n        \"tie_word_embeddings\": False,\n        \"rope_theta\": 1000000.0,\n    },\n    \"minicpm_2b_sft_bf16\": {\n        \"architectures\": [\"MiniCPMForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2304,\n        \"initializer_range\": 0.1,\n        \"intermediate_size\": 5760,\n        \"max_position_embeddings\": 4096,\n        \"model_type\": \"minicpm\",\n        \"num_attention_heads\": 36,\n        \"num_hidden_layers\": 40,\n        \"num_key_value_heads\": 36,\n        \"rms_norm_eps\": 1e-05,\n        \"torch_dtype\": \"bfloat16\",\n        \"tie_word_embeddings\": True,\n        \"transformers_version\": \"4.36.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 122753,\n        \"scale_emb\": 12,\n        \"dim_model_base\": 256,\n        \"scale_depth\": 1.4,\n  
  },\n    \"minicpm-moe-8x2b\": {\n        \"architectures\": [\"MiniCPMForCausalLM\"],\n        \"bos_token_id\": 1,\n        \"eos_token_id\": 2,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2304,\n        \"initializer_range\": 0.1,\n        \"intermediate_size\": 5760,\n        \"max_position_embeddings\": 4096,\n        \"model_type\": \"minicpm\",\n        \"num_attention_heads\": 36,\n        \"num_hidden_layers\": 40,\n        \"num_key_value_heads\": 36,\n        \"rms_norm_eps\": 1e-05,\n        \"rope_scaling\": None,\n        \"torch_dtype\": \"bfloat16\",\n        \"tie_word_embeddings\": True,\n        \"transformers_version\": \"4.36.0\",\n        \"use_cache\": True,\n        \"vocab_size\": 122753,\n        \"scale_emb\": 12,\n        \"dim_model_base\": 256,\n        \"scale_depth\": 1.4,\n        \"num_experts\": 8,\n        \"num_experts_per_tok\": 2,\n    },\n    \"deepseek\": {\n        \"architectures\": [\"DeepseekForCausalLM\"],\n        \"attention_bias\": False,\n        \"attention_dropout\": 0.0,\n        \"bos_token_id\": 100000,\n        \"eos_token_id\": 100001,\n        \"first_k_dense_replace\": 1,\n        \"hidden_act\": \"silu\",\n        \"hidden_size\": 2048,\n        \"initializer_range\": 0.02,\n        \"intermediate_size\": 10944,\n        \"max_position_embeddings\": 4096,\n        \"model_type\": \"deepseek\",\n        \"moe_intermediate_size\": 1408,\n        \"moe_layer_freq\": 1,\n        \"n_routed_experts\": 64,\n        \"n_shared_experts\": 2,\n        \"norm_topk_prob\": False,\n        \"num_attention_heads\": 16,\n        \"num_experts_per_tok\": 6,\n        \"num_hidden_layers\": 28,\n        \"num_key_value_heads\": 16,\n        \"pretraining_tp\": 1,\n        \"rms_norm_eps\": 1e-06,\n        \"rope_scaling\": None,\n        \"rope_theta\": 10000,\n        \"scoring_func\": \"softmax\",\n        \"tie_word_embeddings\": False,\n        \"torch_dtype\": \"bfloat16\",\n        
\"transformers_version\": \"4.36.2\",\n        \"use_cache\": True,\n        \"vocab_size\": 102400,\n    },\n    \"gpt_j\": {\n        \"activation_function\": \"gelu_new\",\n        \"architectures\": [\"GPTJForCausalLM\"],\n        \"attn_pdrop\": 0.0,\n        \"bos_token_id\": 50256,\n        \"embd_pdrop\": 0.0,\n        \"eos_token_id\": 50256,\n        \"initializer_range\": 0.02,\n        \"layer_norm_epsilon\": 1e-05,\n        \"model_type\": \"gptj\",\n        \"n_embd\": 4096,\n        \"n_head\": 16,\n        \"n_inner\": None,\n        \"n_layer\": 28,\n        \"n_positions\": 2048,\n        \"resid_pdrop\": 0.0,\n        \"rotary\": True,\n        \"rotary_dim\": 64,\n        \"scale_attn_weights\": True,\n        \"summary_activation\": None,\n        \"summary_first_dropout\": 0.1,\n        \"summary_proj_to_labels\": True,\n        \"summary_type\": \"cls_index\",\n        \"summary_use_proj\": True,\n        \"rope_scaling\": {\"rope_type\": \"gptj\"},\n        \"tie_word_embeddings\": False,\n        \"tokenizer_class\": \"GPT2Tokenizer\",\n        \"transformers_version\": \"4.18.0.dev0\",\n        \"use_cache\": True,\n        \"vocab_size\": 50400,\n    },\n}\n"
  },
  {
    "path": "python/mlc_llm/model/nemotron/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/nemotron/nemotron_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Nemotron parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .nemotron_model import NemotronForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=NemotronForCausalLM,\n    add_unused=[\"rotary_emb.inv_freq\"],\n    include_gate_up=False,\n)\n"
  },
  {
    "path": "python/mlc_llm/model/nemotron/nemotron_model.py",
    "content": "\"\"\"\nImplementation for Nemotron architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass NemotronConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Nemotron model.\"\"\"\n\n    vocab_size: int\n    max_position_embeddings: int\n    hidden_size: int\n    intermediate_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    rope_theta: int = 10000\n    partial_rotary_factor: float = 0.5\n    rope_scaling: Optional[Dict[str, Any]] = None\n    norm_eps: float = 1e-5\n    head_dim: int = 0\n    tie_word_embeddings: bool = False\n    mlp_bias: bool = False\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    pipeline_parallel_stages: int = 1\n    max_batch_size: int = 1\n    disaggregation: bool = False\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches\n        if self.context_window_size == 0:\n            self.context_window_size = self.max_position_embeddings\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        self.rotary_dim = int(self.partial_rotary_factor * self.head_dim)\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                
bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass NemotronMLP(nn.Module):\n    \"\"\"Nemotron MLP module.\"\"\"\n\n    def __init__(self, config: NemotronConfig):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.hidden_size = config.hidden_size\n        # Per-shard width: up_proj is split along its output dim, down_proj along its input dim.\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n\n        self.up_proj = nn.Linear(config.hidden_size, self.intermediate_size, bias=config.mlp_bias)\n        self.down_proj = nn.Linear(\n            self.intermediate_size, config.hidden_size, bias=config.mlp_bias\n        )\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"Forward pass of the MLP module.\"\"\"\n        out = self.up_proj(x)\n        out = op.square(op.relu(out))\n        out = self.down_proj(out)\n        return out\n\n\nclass NemotronEmbedding(nn.Embedding):\n    \"\"\"The embedding module that can be shared with the final head. 
From Qwen2Embedding.\"\"\"\n\n    def lm_head_forward(self, x: Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass NemotronLayerNorm1P(nn.LayerNorm):\n    \"\"\"Nemotron LayerNorm1P module.\"\"\"\n\n    def __init__(self, normalized_shape: int, eps: float = 1e-5, elementwise_affine: bool = True):\n        super().__init__(normalized_shape, eps, elementwise_affine)\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"Forward pass of the tweaked LayerNorm module.\"\"\"\n        return op.layer_norm(\n            x,\n            normalized_shape=self.normalized_shape,\n            weight=self.weight + 1,\n            bias=self.bias,\n            eps=self.eps,\n        )\n\n\nclass NemotronAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: NemotronConfig):\n        self.head_dim = config.head_dim\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert (\n            config.num_key_value_heads % config.tensor_parallel_shards == 0\n        ), f\"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards\"\n        assert (\n            config.num_key_value_heads >= config.tensor_parallel_shards\n        ), f\"Too large tensor_parallel_shards, must not exceed {config.num_key_value_heads}\"\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, 
paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass NemotronDecoderLayer(nn.Module):\n    def __init__(self, config: NemotronConfig):\n        self.self_attn = NemotronAttention(config)\n        self.mlp = NemotronMLP(config)\n        self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, config.norm_eps)\n        self.post_attention_layernorm = NemotronLayerNorm1P(config.hidden_size, config.norm_eps)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(self.mlp.up_proj, tp.ShardSingleDim(\"_shard_mlp_up\", dim=0))\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, 
residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass NemotronModel(nn.Module):\n    def __init__(self, config: NemotronConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = NemotronEmbedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [NemotronDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = NemotronLayerNorm1P(config.hidden_size, config.norm_eps)\n        self.num_layers_per_stage = (\n            config.num_hidden_layers + config.pipeline_parallel_stages - 1\n        ) // config.pipeline_parallel_stages\n\n        # Compute pipeline layer partition.\n        layers_per_stage = (\n            config.num_hidden_layers + config.pipeline_parallel_stages - 1\n        ) // config.pipeline_parallel_stages\n        self.layer_partition = [\n            i * layers_per_stage for i in range(config.pipeline_parallel_stages)\n        ] + [config.num_hidden_layers]\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.layers):\n            if layer_id != 0 and layer_id in self.layer_partition:\n                hidden_states = op_ext.pipeline_stage_boundary(hidden_states)\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass NemotronForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: NemotronConfig):\n        self.model = 
NemotronModel(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, \"vocab_size\", bias=False)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_scaling = config.rope_scaling\n        self.rope_theta = config.rope_theta\n        self.rotary_dim = config.rotary_dim\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.disaggregation = config.disaggregation\n        self.dtype = \"float32\"\n\n        def _set_pp():\n            # hidden layers (use the same ceil-divided stage size as model.layer_partition)\n            for layer_id in range(config.num_hidden_layers):\n                stage = layer_id // self.model.num_layers_per_stage\n                for _, param in self.model.layers[layer_id].named_parameters():\n                    param.attrs[\"pipeline_stages\"] = [stage]\n            # last stage\n            last_stage = config.pipeline_parallel_stages - 1\n            self.model.norm.weight.attrs[\"pipeline_stages\"] = [last_stage]\n            # embedding table and lm_head are required by all stages\n            all_stages = list(range(config.pipeline_parallel_stages))\n            self.model.embed_tokens.weight.attrs[\"pipeline_stages\"] = all_stages\n            if not config.tie_word_embeddings:\n                self.lm_head.weight.attrs[\"pipeline_stages\"] = all_stages\n\n        _set_pp()\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        
logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            if self.tensor_parallel_shards > 1:\n                logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return self.get_logits(hidden_states)\n\n    def batch_forward_to_last_hidden_states(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        return hidden_states\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def get_logits(self, hidden_states: Tensor):\n        op_ext.configure()\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_select_last_hidden_states(self, hidden_states: Tensor, logit_positions: Tensor):\n        op_ext.configure()\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        hidden_states = op.take(hidden_states, logit_positions, axis=0)\n        return hidden_states\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last token\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        
hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_prefill_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_decode_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = 
self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_verify_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_scaling=self.rope_scaling,\n            rotary_dim=self.rotary_dim,\n            layer_partition=self.model.layer_partition,\n            enable_disaggregation=self.disaggregation,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"get_logits\": {\n 
               \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_select_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n           
         \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    
\"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/olmo/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/olmo/olmo_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's OLMo parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\nfrom mlc_llm.quantization import Quantization, make_awq_quant\n\nfrom .olmo_model import OLMoConfig, OLMoForCausalLM\n\nawq_quant = make_awq_quant(OLMoForCausalLM)\n\n\nhuggingface = make_standard_hf_loader(\n    model_cls=OLMoForCausalLM,\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n\n\ndef awq(model_config: OLMoConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of AWQ parameters.\n\n    Parameters\n    ----------\n    model_config : OLMoConfig\n        The configuration of the OLMo model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to AWQ.\n    \"\"\"\n    model, _ = awq_quant(model_config, quantization)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),  # type: ignore[attr-defined]\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{attn}.qkv_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{quantize_suffix}\",\n                    f\"{attn}.k_proj.{quantize_suffix}\",\n         
           f\"{attn}.v_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate(\n                        [q, k, v],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # Concat gate and up in MLP\n        mlp = f\"model.layers.{i}.mlp\"\n        for quantize_suffix in [\"qweight\", \"qzeros\", \"scales\"]:\n            mlc_name = f\"{mlp}.gate_up_proj.{quantize_suffix}\"\n            assert mlc_name in named_parameters\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{mlp}.gate_proj.{quantize_suffix}\",\n                    f\"{mlp}.up_proj.{quantize_suffix}\",\n                ],\n                functools.partial(\n                    lambda gate, up, dtype: np.concatenate(\n                        [gate, up],\n                        axis=1,  # AWQ GEMM would transpose the weight\n                    ).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n        # inv_freq is not used in the model\n        mapping.add_unused(f\"{attn}.rotary_emb.inv_freq\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(lambda x, dtype: x.astype(dtype), dtype=mlc_param.dtype),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/olmo/olmo_model.py",
    "content": "\"\"\"\nImplementation for OLMo architecture.\nTODO: add docstring\n\"\"\"\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass OLMoConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the OLMo model.\"\"\"\n\n    vocab_size: int = None\n    hidden_size: int = None\n    num_attention_heads: int = None\n    num_key_value_heads: int = 0\n    head_dim: int = 0\n    position_embedding_base: int = 0\n    rope_scaling: Optional[Dict[str, Any]] = None\n    intermediate_size: int = None\n    hidden_act: str = None\n    num_hidden_layers: int = None\n    tie_word_embeddings: bool = False\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    pipeline_parallel_stages: int = 1\n    max_batch_size: int = 1\n    clip_qkv: float = None\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches\n        if self.num_key_value_heads == 0:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 
10000\n\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n        if (\n            self.pipeline_parallel_stages <= 0\n            or self.pipeline_parallel_stages > self.num_hidden_layers\n        ):\n            raise ValueError(\n                f'Invalid \"pipeline_parallel_stages\" value({self.pipeline_parallel_stages}). 
'\n                f'Expected a value between 1 and {self.num_hidden_layers} (inclusive).'\n            )\n\n        if self.clip_qkv is not None:\n            if self.clip_qkv <= 0:\n                raise ValueError(f\"'clip_qkv'({self.clip_qkv}) must be positive\")\n\n\nclass OLMoEmbedding(nn.Embedding):\n    \"\"\"The embedding module that can be shared with the final lm_head. From Qwen2Embedding.\"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass OLMoAttention(nn.Module):  # pylint: disable=missing-class-docstring\n    def __init__(self, config: OLMoConfig):\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert (\n            config.num_key_value_heads >= config.tensor_parallel_shards\n        ), f\"Too large tensor_parallel_shards, must not exceed {config.num_key_value_heads}\"\n        assert (\n            config.num_key_value_heads % config.tensor_parallel_shards == 0\n        ), f\"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards\"\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=False,\n        )\n        self.clip_qkv = config.clip_qkv\n        self.o_proj = nn.Linear(\n            in_features=self.num_q_heads * self.head_dim,\n            out_features=config.hidden_size,\n            bias=False,\n        )\n\n    def forward(  # pylint: disable=missing-function-docstring\n        self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int\n    ):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, 
self.num_kv_heads\n        b, s, _ = hidden_states.shape\n\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n\n        # Clamp after qkv projection if needed\n        if self.clip_qkv is not None:\n            qkv = qkv.maximum(-self.clip_qkv).minimum(self.clip_qkv)\n\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\n# Copied from qwen2_model.ACT2FN\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass OLMoFFN(nn.Module):  # pylint: disable=missing-class-docstring\n    def __init__(self, config: OLMoConfig):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n        self.down_proj = nn.Linear(\n            in_features=self.intermediate_size,\n            out_features=config.hidden_size,\n            bias=False,\n        )\n\n    def forward(self, x: Tensor):  # pylint: disable=missing-function-docstring\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(self.act_fn(x1) * 
x2)\n\n\n# pylint: disable=trailing-whitespace\n\n\nclass OLMoDecoderLayer(nn.Module):  # pylint: disable=missing-class-docstring\n    def __init__(self, config: OLMoConfig):\n        self.input_layernorm = nn.LayerNorm(\n            normalized_shape=config.hidden_size,\n            eps=1e-5,\n            elementwise_affine=False,\n        )\n        self.self_attn = OLMoAttention(config)\n        self.post_attention_layernorm = nn.LayerNorm(\n            normalized_shape=config.hidden_size,\n            eps=1e-5,\n            elementwise_affine=False,\n        )\n        self.mlp = OLMoFFN(config)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n    def forward(  # pylint: disable=missing-function-docstring\n        self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int\n    ):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = 
self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n\nclass OLMoModel(nn.Module):  # pylint: disable=missing-class-docstring\n    def __init__(self, config: OLMoConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = OLMoEmbedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [OLMoDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.LayerNorm(\n            normalized_shape=config.hidden_size,\n            eps=1e-5,\n            elementwise_affine=False,\n        )\n\n        self.num_layers_per_stage = (\n            config.num_hidden_layers + config.pipeline_parallel_stages - 1\n        ) // config.pipeline_parallel_stages\n        # Compute pipeline layer partition.\n        layers_per_stage = (\n            config.num_hidden_layers + config.pipeline_parallel_stages - 1\n        ) // config.pipeline_parallel_stages\n        self.layer_partition = [\n            i * layers_per_stage for i in range(config.pipeline_parallel_stages)\n        ] + [config.num_hidden_layers]\n\n    def forward(  # pylint: disable=missing-function-docstring\n        self, inputs: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            if layer_id != 0 and layer_id in self.layer_partition:\n                hidden_states = op_ext.pipeline_stage_boundary(hidden_states)\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass OLMoForCausalLM(  # pylint: disable=missing-class-docstring,too-many-instance-attributes\n    nn.Module\n):\n    def __init__(self, config: OLMoConfig):\n        
self.model = OLMoModel(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.rope_theta = config.position_embedding_base\n        self.rope_scaling = config.rope_scaling\n        self.intermediate_size = config.intermediate_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n        def _set_pp():\n            # hidden layers: use the same ceil-divided stage size as OLMoModel.layer_partition,\n            # so every layer maps to a valid stage even when the layers do not divide evenly\n            for layer_id in range(config.num_hidden_layers):\n                stage = layer_id // self.model.num_layers_per_stage\n                for _, param in self.model.layers[layer_id].named_parameters():\n                    param.attrs[\"pipeline_stages\"] = [stage]\n\n            # embedding table and lm_head are required by all stages\n            all_stages = list(range(config.pipeline_parallel_stages))\n            self.model.embed_tokens.weight.attrs[\"pipeline_stages\"] = all_stages\n            if not config.tie_word_embeddings:\n                self.lm_head.weight.attrs[\"pipeline_stages\"] = all_stages\n\n        _set_pp()\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(  # pylint: disable=missing-function-docstring\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embeds, 
paged_kv_cache)\n        if logit_positions is not None:\n            if self.tensor_parallel_shards > 1:\n                logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return self.get_logits(hidden_states)\n\n    def batch_forward_to_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        return hidden_states\n\n    def embed(self, input_ids: Tensor):  # pylint: disable=missing-function-docstring\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def get_logits(self, hidden_states: Tensor):  # pylint: disable=missing-function-docstring\n        op_ext.configure()\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_select_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self, hidden_states: Tensor, logit_positions: Tensor\n    ):\n        op_ext.configure()\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        hidden_states = op.take(hidden_states, logit_positions, axis=0)\n        return hidden_states\n\n    def prefill(  # pylint: disable=missing-function-docstring\n        self, input_embed: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # get tensor of the last sequence\n            b, s, d = x.shape\n            return 
te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k])\n\n        # pylint: disable=trailing-whitespace\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    # pylint: disable=trailing-whitespace\n    def decode(  # pylint: disable=missing-function-docstring\n        self, input_embed: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def prefill_to_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self, input_embed: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def decode_to_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self, input_embed: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_prefill(  # pylint: disable=missing-function-docstring\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(  # pylint: disable=missing-function-docstring\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(  # pylint: disable=missing-function-docstring\n        self, input_embeds: Tensor, 
paged_kv_cache: PagedKVCache\n    ):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_prefill_to_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_decode_to_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_verify_to_last_hidden_states(  # pylint: disable=missing-function-docstring\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        hidden_states = self.batch_forward_to_last_hidden_states(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=missing-function-docstring,too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            
v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_scaling=self.rope_scaling,\n            layer_partition=self.model.layer_partition,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):  # pylint: disable=missing-function-docstring\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"get_logits\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_select_last_hidden_states\": {\n                \"hidden_states\": nn.spec.Tensor([\"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n  
              },\n            },\n            \"prefill_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        
    \"batch_prefill_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/orion/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/orion/orion_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Orion parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .orion_model import OrionForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=OrionForCausalLM,\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n"
  },
  {
    "path": "python/mlc_llm/model/orion/orion_model.py",
    "content": "\"\"\"\nImplementation for Orion-14B architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass OrionConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Orion model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    rms_norm_eps: float\n    vocab_size: int\n    position_embedding_base: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    num_key_value_heads: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.num_key_value_heads == 0:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass OrionFFN(nn.Module):\n    def __init__(self, config: OrionConfig):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} 
GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass OrionAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: OrionConfig):\n        self.head_dim = config.head_dim\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert (\n            config.num_key_value_heads % config.tensor_parallel_shards == 0\n        ), f\"num_kv_heads({config.num_key_value_heads}) must be divisible by tensor_parallel_shards\"\n        assert (\n            config.num_key_value_heads >= config.tensor_parallel_shards\n        ), f\"Too large tensor_parallel_shards, must be no larger than {config.num_key_value_heads}\"\n        self.num_kv_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_kv_heads) * self.head_dim,\n            bias=False,\n        )\n        self.o_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_kv_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = 
op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.o_proj(output)\n\n\nclass OrionDecoderLayer(nn.Module):\n    def __init__(self, config: OrionConfig):\n        rms_norm_eps = config.rms_norm_eps\n        self.self_attn = OrionAttention(config)\n        self.mlp = OrionFFN(config)\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, rms_norm_eps)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, rms_norm_eps)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_q_heads * hd\n            k = self.self_attn.num_kv_heads * hd\n            v = self.self_attn.num_kv_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.self_attn.o_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def 
_apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass OrionModel(nn.Module):\n    def __init__(self, config: OrionConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(\"vocab_size\", config.hidden_size)\n        self.layers = nn.ModuleList(\n            [OrionDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.LayerNorm(config.hidden_size, config.rms_norm_eps)\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass OrionForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: OrionConfig):\n        self.model = OrionModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, \"vocab_size\", bias=False)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n  
      logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    
def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n         
           \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": 
\"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/phi/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/phi/phi_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Phi parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .phi_model import Phi1Config, PhiConfig, PhiForCausalLM\n\n\ndef huggingface(model_config: PhiConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : PhiConfig\n        The configuration of the Phi model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = PhiForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params = model.export_tvm(  # pylint: disable=W0632:unbalanced-tuple-unpacking\n        spec=model.get_default_spec()\n    )\n    named_parameters = dict(_named_params)\n    mapping = ExternMapping()\n\n    def _add(mlc_name, hf_name):\n        mapping.add_mapping(\n            mlc_name,\n            [hf_name],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=named_parameters[mlc_name].dtype,\n            ),\n        )\n\n    if model_config.model_type == \"mixformer-sequential\":\n        _add(\"transformer.embd.weight\", \"layers.0.wte.weight\")\n        prefix = \"transformer.h\"\n        for i in range(model_config.n_layer):\n            _add(f\"{prefix}.{i}.ln.weight\", f\"layers.{i + 1}.ln.weight\")\n            _add(f\"{prefix}.{i}.ln.bias\", f\"layers.{i + 1}.ln.bias\")\n            _add(f\"{prefix}.{i}.mixer.Wqkv.weight\", f\"layers.{i + 1}.mixer.Wqkv.weight\")\n            
_add(f\"{prefix}.{i}.mixer.Wqkv.bias\", f\"layers.{i + 1}.mixer.Wqkv.bias\")\n            _add(\n                f\"{prefix}.{i}.mixer.out_proj.weight\",\n                f\"layers.{i + 1}.mixer.out_proj.weight\",\n            )\n            _add(\n                f\"{prefix}.{i}.mixer.out_proj.bias\",\n                f\"layers.{i + 1}.mixer.out_proj.bias\",\n            )\n            _add(f\"{prefix}.{i}.mlp.fc1.weight\", f\"layers.{i + 1}.mlp.fc1.weight\")\n            _add(f\"{prefix}.{i}.mlp.fc1.bias\", f\"layers.{i + 1}.mlp.fc1.bias\")\n            _add(f\"{prefix}.{i}.mlp.fc2.weight\", f\"layers.{i + 1}.mlp.fc2.weight\")\n            _add(f\"{prefix}.{i}.mlp.fc2.bias\", f\"layers.{i + 1}.mlp.fc2.bias\")\n            mapping.add_unused(f\"layers.{i + 1}.mixer.rotary_emb.inv_freq\")\n        prefix = f\"layers.{model_config.n_layer + 1}\"\n        _add(\"lm_head.ln.weight\", f\"{prefix}.ln.weight\")\n        _add(\"lm_head.ln.bias\", f\"{prefix}.ln.bias\")\n        _add(\"lm_head.linear.weight\", f\"{prefix}.linear.weight\")\n        _add(\"lm_head.linear.bias\", f\"{prefix}.linear.bias\")\n\n    elif model_config.model_type == \"phi-msft\":\n        _add(\"transformer.embd.weight\", \"transformer.embd.wte.weight\")\n        for mlc_name, _ in named_parameters.items():\n            if mlc_name not in mapping.param_map:\n                _add(mlc_name, mlc_name)\n    return mapping\n\n\ndef phi1_huggingface(model_config: Phi1Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of Phi-1/Phi-1.5 HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : PhiConfig\n        The configuration of the Phi model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = 
PhiForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params = model.export_tvm(  # pylint: disable=W0632:unbalanced-tuple-unpacking\n        spec=model.get_default_spec()\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    def _add(mlc_name, hf_name):\n        mapping.add_mapping(\n            mlc_name,\n            [hf_name],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=named_parameters[mlc_name].dtype,\n            ),\n        )\n\n    def _concat_add(mlc_name, hf_names):\n        mapping.add_mapping(\n            mlc_name,\n            hf_names,\n            functools.partial(\n                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                dtype=named_parameters[mlc_name].dtype,\n            ),\n        )\n\n    _add(\"lm_head.linear.weight\", \"lm_head.weight\")\n    _add(\"lm_head.linear.bias\", \"lm_head.bias\")\n    _add(\"lm_head.ln.weight\", \"model.final_layernorm.weight\")\n    _add(\"lm_head.ln.bias\", \"model.final_layernorm.bias\")\n    _add(\"transformer.embd.weight\", \"model.embed_tokens.weight\")\n\n    prefix = \"transformer.h\"\n    hf_prefix = \"model.layers\"\n    for i in range(model_config.num_hidden_layers):\n        _add(f\"{prefix}.{i}.ln.weight\", f\"{hf_prefix}.{i}.input_layernorm.weight\")\n        _add(f\"{prefix}.{i}.ln.bias\", f\"{hf_prefix}.{i}.input_layernorm.bias\")\n        _concat_add(\n            f\"{prefix}.{i}.mixer.Wqkv.weight\",\n            [\n                f\"{hf_prefix}.{i}.self_attn.q_proj.weight\",\n                f\"{hf_prefix}.{i}.self_attn.k_proj.weight\",\n                f\"{hf_prefix}.{i}.self_attn.v_proj.weight\",\n            ],\n        )\n        _concat_add(\n            f\"{prefix}.{i}.mixer.Wqkv.bias\",\n            [\n                f\"{hf_prefix}.{i}.self_attn.q_proj.bias\",\n                
f\"{hf_prefix}.{i}.self_attn.k_proj.bias\",\n                f\"{hf_prefix}.{i}.self_attn.v_proj.bias\",\n            ],\n        )\n        _add(\n            f\"{prefix}.{i}.mixer.out_proj.weight\",\n            f\"{hf_prefix}.{i}.self_attn.dense.weight\",\n        )\n        _add(f\"{prefix}.{i}.mixer.out_proj.bias\", f\"{hf_prefix}.{i}.self_attn.dense.bias\")\n        _add(f\"{prefix}.{i}.mlp.fc1.weight\", f\"{hf_prefix}.{i}.mlp.fc1.weight\")\n        _add(f\"{prefix}.{i}.mlp.fc1.bias\", f\"{hf_prefix}.{i}.mlp.fc1.bias\")\n        _add(f\"{prefix}.{i}.mlp.fc2.weight\", f\"{hf_prefix}.{i}.mlp.fc2.weight\")\n        _add(f\"{prefix}.{i}.mlp.fc2.bias\", f\"{hf_prefix}.{i}.mlp.fc2.bias\")\n        mapping.add_unused(f\"{hf_prefix}.{i}.mixer.rotary_emb.inv_freq\")\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/phi/phi_model.py",
    "content": "\"\"\"\nImplementation for Phi architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional, Union\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Phi1Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Phi-1/Phi-1.5 model.\"\"\"\n\n    vocab_size: int = 51200\n    hidden_size: int = 2048\n    intermediate_size: int = 8192\n    num_hidden_layers: int = 24\n    num_attention_heads: int = 32\n    layer_norm_eps: float = 1e-5\n    position_embedding_base: int = 0\n    partial_rotary_factor: float = 0.5\n    num_key_value_heads: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        if self.num_key_value_heads == 0 or self.num_key_value_heads is None:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.intermediate_size == 0 or self.intermediate_size is None:\n            self.intermediate_size = 4 * self.hidden_size\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n\n\n@dataclasses.dataclass\nclass PhiConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Phi-2 model.\"\"\"\n\n    model_type: str  # \"phi\", \"phi-msft\", \"mixformer-sequential\"\n    vocab_size: int 
= 51200\n    n_positions: int = 2048\n    n_embd: int = 2560\n    n_layer: int = 32\n    n_inner: int = 0\n    n_head: int = 32\n    rotary_dim: int = 32\n    position_embedding_base: int = 0\n    layer_norm_epsilon: float = 1e-5\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    n_head_kv: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                self.context_window_size = self.n_positions\n                logger.info(\n                    \"%s not found in config.json. 
Falling back to %s (%d)\",\n                    bold(\"context_window_size\"),\n                    \"n_positions\",\n                    self.context_window_size,\n                )\n        if self.prefill_chunk_size == 0:\n            self.prefill_chunk_size = self.context_window_size\n        self.prefill_chunk_size = min(self.prefill_chunk_size, self.context_window_size)\n        if self.n_head_kv == 0 or self.n_head_kv is None:\n            self.n_head_kv = self.n_head\n        if self.n_inner == 0 or self.n_inner is None:\n            self.n_inner = 4 * self.n_embd\n        if self.head_dim == 0:\n            self.head_dim = self.n_embd // self.n_head\n        assert self.head_dim * self.n_head == self.n_embd\n        assert self.n_head % self.n_head_kv == 0\n\n    @staticmethod\n    def from_phi1(config: Phi1Config) -> \"PhiConfig\":\n        \"Build PhiConfig from a Phi1Config.\"\n        return PhiConfig(\n            model_type=\"phi\",\n            vocab_size=config.vocab_size,\n            n_positions=config.context_window_size,\n            n_embd=config.hidden_size,\n            n_layer=config.num_hidden_layers,\n            n_inner=config.intermediate_size,\n            n_head=config.num_attention_heads,\n            rotary_dim=int(config.partial_rotary_factor * config.head_dim),\n            position_embedding_base=config.position_embedding_base,\n            layer_norm_epsilon=config.layer_norm_eps,\n            context_window_size=config.context_window_size,\n            prefill_chunk_size=config.prefill_chunk_size,\n            n_head_kv=config.num_key_value_heads,\n            head_dim=config.head_dim,\n            tensor_parallel_shards=config.tensor_parallel_shards,\n            kwargs=config.kwargs,\n        )\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass PhiMLP(nn.Module):\n    def __init__(self, config: PhiConfig):\n        super().__init__()\n        if config.n_inner % config.tensor_parallel_shards != 0:\n            
raise ValueError(\n                f\"Cannot split MLP intermediate size {config.n_inner} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.n_inner // config.tensor_parallel_shards\n        self.fc1 = nn.Linear(config.n_embd, self.intermediate_size)\n        self.fc2 = nn.Linear(self.intermediate_size, config.n_embd)\n\n    def forward(self, hidden_states: Tensor):\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = op.gelu(hidden_states, approximate=\"tanh\")\n        hidden_states = self.fc2(hidden_states)\n\n        return hidden_states\n\n\nclass PhiMHA(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: PhiConfig):\n        self.num_q_heads = config.n_head // config.tensor_parallel_shards\n        assert (\n            config.n_head % config.tensor_parallel_shards == 0\n        ), f\"n_head({config.n_head}) must be divisible by tensor_parallel_shards\"\n        self.n_head_kv = config.n_head_kv // config.tensor_parallel_shards\n        assert (\n            config.n_head_kv % config.tensor_parallel_shards == 0\n        ), f\"n_head_kv({config.n_head_kv}) must be divisible by tensor_parallel_shards\"\n        self.head_dim = config.head_dim\n        op_size = self.head_dim * (self.num_q_heads + 2 * self.n_head_kv)\n        hidden_size = config.n_embd\n\n        self.Wqkv = nn.Linear(hidden_size, op_size, bias=True)\n        self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, hidden_size, bias=True)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.n_head_kv\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.Wqkv(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            
paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.out_proj(output)\n\n\nclass PhiParallelBlock(nn.Module):\n    def __init__(self, config: PhiConfig):\n        super().__init__()\n\n        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n        self.mixer = PhiMHA(config)\n        self.mlp = PhiMLP(config)\n\n        def _set_tp():\n            def _set(param, hint):\n                param.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.mixer.num_q_heads * hd\n            k = self.mixer.n_head_kv * hd\n            v = self.mixer.n_head_kv * hd\n            _set(\n                self.mixer.Wqkv.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", segs=[q, k, v], dim=0),\n            )\n            _set(\n                self.mixer.Wqkv.bias,\n                tp.ShardSingleDim(\"_shard_qkv_bias\", segs=[q, k, v], dim=0),\n            )\n            _set(self.mixer.out_proj.weight, tp.ShardSingleDim(\"_shard_o_weight\", dim=1))\n            _set(self.mlp.fc1.weight, tp.ShardSingleDim(\"_shard_mlp_fc1_weight\", dim=0))\n            _set(self.mlp.fc1.bias, tp.ShardSingleDim(\"_shard_mlp_fc1_bias\", dim=0))\n            _set(self.mlp.fc2.weight, tp.ShardSingleDim(\"_shard_mlp_fc2_weight\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        residual = hidden_states\n        hidden_states = self.ln(hidden_states)\n\n        with (\n            tp.shard_bias(self.mixer.out_proj, self.tensor_parallel_shards),\n            tp.shard_bias(self.mlp.fc2, self.tensor_parallel_shards),\n        ):\n            attn_outputs = self.mixer(hidden_states, paged_kv_cache, layer_id)\n            
feed_forward_hidden_states = self.mlp(hidden_states)\n\n        hidden_states = self._apply_parallel_residual(\n            attn_outputs, feed_forward_hidden_states, residual\n        )\n\n        return hidden_states\n\n    def _apply_parallel_residual(self, attn_out, mlp_out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(\n                attn_out + mlp_out + residual / self.tensor_parallel_shards, \"sum\"\n            )\n        return attn_out + mlp_out + residual\n\n\nclass PhiCausalLMHead(nn.Module):\n    def __init__(self, config: PhiConfig) -> None:\n        super().__init__()\n\n        self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)\n        self.linear = nn.Linear(config.n_embd, \"vocab_size\")\n\n    def forward(self, hidden_states: Tensor):\n        hidden_states = self.ln(hidden_states)\n        logits = self.linear(hidden_states)\n\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n\nclass PhiModel(nn.Module):\n    def __init__(self, config: PhiConfig) -> None:\n        super().__init__()\n        self.embd = nn.Embedding(config.vocab_size, config.n_embd)\n        self.h = nn.ModuleList([PhiParallelBlock(config) for _ in range(config.n_layer)])\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.h):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n\n        return hidden_states\n\n\nclass PhiForCausalLM(nn.Module):\n    # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Union[PhiConfig, Phi1Config]) -> None:\n        super().__init__()\n\n        if isinstance(config, Phi1Config):\n            config = PhiConfig.from_phi1(config)\n\n        self.transformer = PhiModel(config)\n        self.lm_head = PhiCausalLMHead(config)\n        self.num_hidden_layers = 
config.n_layer\n        self.num_attention_heads = config.n_head\n        self.num_key_value_heads = config.n_head_kv\n        self.head_dim = config.head_dim\n        self.hidden_size = config.n_embd\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.position_embedding_base\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.rotary_dim = config.rotary_dim\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        lm_logits = self.lm_head(hidden_states)\n        if lm_logits.dtype != \"float32\":\n            lm_logits = lm_logits.astype(\"float32\")\n        return lm_logits\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n     
   if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        embeds = self.transformer.embd(input_ids)\n        return embeds\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            
qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rotary_dim=self.rotary_dim,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n               
     \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/phi3/__init__.py",
    "content": "\"\"\"Common `nn.Modules` used to define LLMs in this project.\"\"\"\n\nfrom .phi3_model import Phi3Model\n"
  },
  {
    "path": "python/mlc_llm/model/phi3/phi3_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Phi-3 parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .phi3_model import Phi3Config, Phi3ForCausalLM\n\n\ndef phi3_huggingface(model_config: Phi3Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of Phi-3 HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Phi3Config\n        The configuration of the Phi-3 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Phi3ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params = model.export_tvm(  # pylint: disable=W0632:unbalanced-tuple-unpacking\n        spec=model.get_default_spec()\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    def _add(mlc_name, hf_name):\n        mapping.add_mapping(\n            mlc_name,\n            [hf_name],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=named_parameters[mlc_name].dtype,\n            ),\n        )\n\n    # Skip lm_head.weight if tie_word_embeddings is enabled\n    if not getattr(model_config, \"tie_word_embeddings\", False):\n        _add(\"lm_head.weight\", \"lm_head.weight\")\n    _add(\"transformer.norm.weight\", \"model.norm.weight\")\n    _add(\"transformer.embd.weight\", \"model.embed_tokens.weight\")\n\n    prefix = \"transformer.h\"\n    hf_prefix = \"model.layers\"\n    for i in range(model_config.num_hidden_layers):\n        _add(f\"{prefix}.{i}.ln.weight\", 
f\"{hf_prefix}.{i}.input_layernorm.weight\")\n        _add(\n            f\"{prefix}.{i}.mlp.down_proj.weight\",\n            f\"{hf_prefix}.{i}.mlp.down_proj.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.mlp.gate_up_proj.weight\",\n            f\"{hf_prefix}.{i}.mlp.gate_up_proj.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.post_attention_layernorm.weight\",\n            f\"{hf_prefix}.{i}.post_attention_layernorm.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.mixer.out_proj.weight\",\n            f\"{hf_prefix}.{i}.self_attn.o_proj.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.mixer.qkv_proj.weight\",\n            f\"{hf_prefix}.{i}.self_attn.qkv_proj.weight\",\n        )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/phi3/phi3_model.py",
    "content": "\"\"\"\nImplementation for Phi-3 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Phi3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Phi-3 model.\"\"\"\n\n    model_type: str  # \"phi3\"\n    hidden_size: int\n    vocab_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    intermediate_size: int\n    rms_norm_eps: float\n    num_key_value_heads: int\n    max_position_embeddings: int\n    position_embedding_base: int = 0\n    rope_scaling: Optional[Dict[str, Any]] = None\n    original_max_position_embeddings: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    tie_word_embeddings: bool = False\n    partial_rotary_factor: float = 1.0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.rope_scaling is not None:\n            if \"type\" not in self.rope_scaling:\n                self.rope_scaling = None\n            else:\n                if self.rope_scaling[\"type\"] == \"su\":\n                    self.rope_scaling[\"type\"] = \"longrope\"\n\n          
      assert (\n                    self.rope_scaling[\"type\"] == \"longrope\"\n                ), f\"Unsupported RoPE scaling type {self.rope_scaling['type']} for Phi3\"\n                self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n                (\n                    self.rope_scaling[\"max_position_embeddings\"],\n                    self.rope_scaling[\"original_max_position_embeddings\"],\n                ) = (\n                    self.max_position_embeddings,\n                    self.original_max_position_embeddings,\n                )\n\n        if self.context_window_size == 0:\n            self.context_window_size = self.max_position_embeddings\n\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n        if self.num_key_value_heads == 0 or self.num_key_value_heads is None:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass Phi3Embedding(nn.Embedding):\n    \"\"\"The embedding module that can be shared with the final lm_head.\"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        
\"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass Phi3MLP(nn.Module):\n    def __init__(self, config: Phi3Config):\n        super().__init__()\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor):\n        up_states = self.gate_up_proj(hidden_states)\n        gate, up_states = nn.op.split(up_states, 2, axis=-1)\n        up_states = up_states * op.silu(gate)\n        return self.down_proj(up_states)\n\n\nclass PhiMHA(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Phi3Config):\n        self.num_q_heads = config.num_attention_heads // config.tensor_parallel_shards\n        assert config.num_attention_heads % config.tensor_parallel_shards == 0, (\n            f\"num_attention_heads({config.num_attention_heads}) \"\n            \"must be divisible by tensor_parallel_shards\"\n        )\n        self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        assert config.num_key_value_heads % config.tensor_parallel_shards == 0, (\n            f\"num_key_value_heads({config.num_key_value_heads}) \"\n            \"must be divisible by tensor_parallel_shards\"\n        )\n        self.head_dim = config.head_dim\n\n        self.qkv_proj = nn.Linear(\n            
in_features=config.hidden_size,\n            out_features=(self.num_q_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=False,\n        )\n        self.out_proj = nn.Linear(self.num_q_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_q_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        # QKV Projection\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        # Attention\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_q_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        return self.out_proj(output)\n\n\nclass Phi3ParallelBlock(nn.Module):\n    def __init__(self, config: Phi3Config):\n        super().__init__()\n\n        self.ln = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.mixer = PhiMHA(config)\n        self.mlp = Phi3MLP(config)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.weight.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.mixer.num_q_heads * hd\n            k = self.mixer.num_key_value_heads * hd\n            v = self.mixer.num_key_value_heads * hd\n            i = self.mlp.intermediate_size\n\n            _set(\n                self.mixer.qkv_proj,\n                tp.ShardSingleDim(\"_shard_qkv\", segs=[q, k, v], dim=0),\n            )\n            _set(self.mixer.out_proj, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj,\n                tp.ShardSingleDim(\"_shard_mlp_up\", 
segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        attn_outputs = self.mixer(self.ln(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_parallel_residual(attn_outputs, hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_parallel_residual(out, hidden_states)\n        return hidden_states\n\n    def _apply_parallel_residual(self, mlp_out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(mlp_out + residual / self.tensor_parallel_shards, \"sum\")\n        return mlp_out + residual\n\n\nclass Phi3Model(nn.Module):\n    def __init__(self, config: Phi3Config) -> None:\n        super().__init__()\n        self.embd = Phi3Embedding(config.vocab_size, config.hidden_size)\n        self.h = nn.ModuleList([Phi3ParallelBlock(config) for _ in range(config.num_hidden_layers)])\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = input_embed\n        for layer_id, layer in enumerate(self.h):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Phi3ForCausalLM(nn.Module):\n    # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Phi3Config) -> None:\n        super().__init__()\n\n        self.transformer = Phi3Model(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, \"vocab_size\", 
bias=False)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_scaling = config.rope_scaling\n        self.rope_theta = config.position_embedding_base\n        self.rope_ext_factors = (\n            (config.rope_scaling[\"long_factor\"] + config.rope_scaling[\"short_factor\"])\n            if config.rope_scaling is not None\n            else None\n        )\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.partial_rotary_factor = config.partial_rotary_factor\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def get_logits(self, hidden_states: Tensor):\n        op_ext.configure()\n        if self.tie_word_embeddings:\n            logits = self.transformer.embd.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return self.get_logits(hidden_states)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, 
k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.transformer(input_embed, paged_kv_cache)\n        logits = self.get_logits(hidden_states)\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        embeds = self.transformer.embd(input_ids)\n        return embeds\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            
max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scaling=self.rope_scaling,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_ext_factors=self.rope_ext_factors,\n            rotary_dim=int(self.head_dim * self.partial_rotary_factor),\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n               
 \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/phi3v/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/phi3v/phi3v_image.py",
    "content": "\"\"\"\nImplementation of the Phi-3 Vision image embedding.\n\"\"\"\n\nfrom tvm import relax, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Module, Tensor, op\nfrom tvm.script import tir as T\n\nfrom mlc_llm.model.vision import CLIPVisionModel\nfrom mlc_llm.support.config import ConfigBase\n\n\n# mypy: disable-error-code=\"attr-defined\"\n# pylint: disable=invalid-name,missing-docstring\nclass ImageProjection(Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: ConfigBase):\n        super().__init__()\n        self.linear_1 = nn.Linear(\n            config.vision_config.hidden_size * 4, config.hidden_size, bias=True\n        )\n        self.act = nn.GELU()\n        self.linear_2 = nn.Linear(config.hidden_size, config.hidden_size, bias=True)\n\n    def forward(self, image_features: Tensor) -> Tensor:\n        shape_1 = tir.Var(\"shape_1\", \"int64\")\n        image_features = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                image_features._expr,  # pylint: disable=protected-access\n                relax.TensorStructInfo([shape_1, image_features.shape[1]], image_features.dtype),\n            ),\n            \"image_features\",\n        )\n\n        hidden_states = self.linear_1(image_features)\n\n        shape_2 = tir.Var(\"shape_2\", \"int64\")\n        hidden_states = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                hidden_states._expr,  # pylint: disable=protected-access\n                relax.TensorStructInfo([shape_2, hidden_states.shape[1]], hidden_states.dtype),\n            ),\n            \"hidden_states\",\n        )\n\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.linear_2(hidden_states)\n        return hidden_states\n\n\nclass Phi3ImageEmbedding(Module):\n    def __init__(self, config: ConfigBase):\n        
super().__init__()\n\n        self.img_processor = CLIPVisionModel(config.vision_config)\n        self.image_dim_out = config.img_processor[\"image_dim_out\"]\n\n        self.glb_GN = nn.Parameter((1, 1, self.image_dim_out * 4))\n        self.sub_GN = nn.Parameter((1, 1, 1, self.image_dim_out * 4))\n\n        self.img_projection = ImageProjection(config)\n        self.image_size = config.vision_config.image_size\n\n    # pylint: disable=dangerous-default-value\n    def apply_schedule(self, sch, block, bdx=32, tile=[32, 32]):\n        loop_x, loop_y = sch.get_loops(block)[-2:]\n        xo, xi = sch.split(loop_x, factors=[tile[0], None])\n        yo, yi = sch.split(loop_y, factors=[tile[1], None])\n        sch.reorder(xo, yo, xi, yi)\n        t = sch.fuse(xo, yo)\n        ty, tx = sch.split(t, factors=[None, bdx])\n        sch.bind(ty, \"threadIdx.y\")\n        sch.bind(tx, \"threadIdx.x\")\n\n    # pylint: disable=too-many-arguments,too-many-locals\n    def dyn_repeat_4d_tensor(self, input_tensor, r0, r1, r2, r3) -> Tensor:\n        assert 4 == input_tensor.ndim, \"input_tensor should be 4D data tensor\"\n\n        def create_dyn_repeat_func(dtype):\n            @T.prim_func\n            def dyn_repeat_4d_tensor_func(  # pylint: disable=too-many-locals\n                input_tensor: T.handle,\n                output: T.handle,\n                ch0: T.int64(),\n                ch1: T.int64(),\n                ch2: T.int64(),\n                ch3: T.int64(),\n            ):\n                T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n                n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()\n                input_tensor_buf = T.match_buffer(input_tensor, (n, c, h, w), dtype=dtype)\n                out_buf = T.match_buffer(output, (n * ch0, c * ch1, h * ch2, w * ch3), dtype=dtype)\n\n                for n_idx in T.thread_binding(n * ch0, thread=\"blockIdx.x\"):\n                    for c_idx in T.thread_binding(c * 
ch1, thread=\"blockIdx.y\"):\n                        for h_idx, w_idx in T.grid(h * ch2, w * ch3):\n                            with T.sblock(\"dyn_repeat_4d_tensor\"):\n                                T.reads(input_tensor_buf[n_idx, c_idx, h_idx, w_idx])\n                                T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])\n                                out_buf[n_idx, c_idx, h_idx, w_idx] = input_tensor_buf[\n                                    n_idx % n, c_idx % c, h_idx % h, w_idx % w\n                                ]\n\n            return dyn_repeat_4d_tensor_func\n\n        n, c, h, w = input_tensor.shape\n        out = op.tensor_ir_op(\n            create_dyn_repeat_func(input_tensor.dtype),\n            \"dyn_repeat_4d_tensor\",\n            [input_tensor, r0, r1, r2, r3],\n            [Tensor.placeholder([n * r0, c * r1, h * r2, w * r3], input_tensor.dtype)],\n        )\n        return out\n\n    def dyn_concate_dim_2(self, input_1, input_2) -> Tensor:\n        def create_dyn_concate_func(dtype):\n            @T.prim_func\n            def dyn_concate_dim_2_func(input_1: T.handle, input_2: T.handle, output: T.handle):\n                T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n                n, c, h1, h2, w = T.int64(), T.int64(), T.int64(), T.int64(), T.int64()\n                input_1_buf = T.match_buffer(input_1, (n, c, h1, w), dtype=dtype)\n                input_2_buf = T.match_buffer(input_2, (n, c, h2, w), dtype=dtype)\n                out_buf = T.match_buffer(output, (n, c, h1 + h2, w), dtype=dtype)\n\n                for n_idx in T.thread_binding(n, thread=\"blockIdx.x\"):\n                    for c_idx in T.thread_binding(c, thread=\"blockIdx.y\"):\n                        for h_idx, w_idx in T.grid(h1 + h2, w):\n                            with T.sblock(\"dyn_concate_dim_2\"):\n                                T.reads(input_1_buf[n_idx, c_idx, h_idx, w_idx])\n                                
T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])\n                                if h_idx < h1:\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = input_1_buf[\n                                        n_idx, c_idx, h_idx, w_idx\n                                    ]\n                                else:\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = input_2_buf[\n                                        n_idx, c_idx, h_idx - h1, w_idx\n                                    ]\n\n            return dyn_concate_dim_2_func\n\n        n1, c1, h1, w1 = input_1.shape\n        n2, c2, h2, w2 = input_2.shape\n        assert n1 == n2 and c1 == c2 and w1 == w2\n\n        out = op.tensor_ir_op(\n            create_dyn_concate_func(input_1.dtype),\n            \"dyn_concate_dim_2\",\n            [input_1, input_2],\n            [Tensor.placeholder([n1, c1, h1 + h2, w1], input_1.dtype)],\n        )\n        return out\n\n    def dyn_concate_dim_1(self, input_1, input_2) -> Tensor:\n        def create_dyn_concate_func(dtype):\n            @T.prim_func\n            def dyn_concate_dim_1_func(input_1: T.handle, input_2: T.handle, output: T.handle):\n                T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n                c, h1, h2, w = T.int64(), T.int64(), T.int64(), T.int64()\n                input_1_buf = T.match_buffer(input_1, (c, h1, w), dtype=dtype)\n                input_2_buf = T.match_buffer(input_2, (c, h2, w), dtype=dtype)\n                out_buf = T.match_buffer(output, (c, h1 + h2, w), dtype=dtype)\n\n                for c_idx in T.thread_binding(c, thread=\"blockIdx.y\"):\n                    for h_idx, w_idx in T.grid(h1 + h2, w):\n                        with T.sblock(\"dyn_concate_dim_1\"):\n                            T.reads(input_1_buf[c_idx, h_idx, w_idx])\n                            T.writes(out_buf[c_idx, h_idx, w_idx])\n                            if h_idx < h1:\n     
                           out_buf[c_idx, h_idx, w_idx] = input_1_buf[c_idx, h_idx, w_idx]\n                            else:\n                                out_buf[c_idx, h_idx, w_idx] = input_2_buf[c_idx, h_idx - h1, w_idx]\n\n            return dyn_concate_dim_1_func\n\n        c1, h1, w1 = input_1.shape\n        c2, h2, w2 = input_2.shape\n        assert c1 == c2 and w1 == w2\n\n        out = op.tensor_ir_op(\n            create_dyn_concate_func(input_1.dtype),\n            \"dyn_concate\",\n            [input_1, input_2],\n            [Tensor.placeholder([c1, h1 + h2, w1], input_1.dtype)],\n        )\n        return out\n\n    def get_img_features(self, img_embeds: Tensor) -> Tensor:\n        img_processor_output = self.img_processor(img_embeds)\n        patch_feature = nn.op.split(img_processor_output, indices_or_sections=[1], axis=1)\n        return patch_feature[1]\n\n    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):\n        N, L, C = image_features.shape\n        num_images = 1\n        H = int(L**0.5)\n        image_features = nn.op.reshape(image_features, ([N, H, H, C]))  # N, 24, 24, 1024\n        image_features = nn.op.reshape(\n            image_features, ([N, H // 2, 2, H // 2, 2, C])\n        )  # N, 12, 2, 12, 2, 1024\n\n        new_s1 = tir.Var(\"new_s1\", \"int64\")\n        new_s2 = tir.Var(\"new_s2\", \"int64\")\n\n        image_features = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                image_features._expr,  # pylint: disable=protected-access\n                relax.TensorStructInfo(\n                    [\n                        image_features.shape[0],\n                        new_s1,\n                        image_features.shape[2],\n                        new_s2,\n                        image_features.shape[4],\n                        image_features.shape[5],\n                    ],\n                    image_features.dtype,\n                
),\n            ),\n            \"image_features_1\",\n        )\n\n        image_features = nn.op.permute_dims(\n            image_features, axes=([0, 1, 3, 2, 4, 5])\n        )  # N, 12, 12, 2, 2, 1024\n        image_features = nn.op.reshape(image_features, ([N, -1, 4 * C]))  # N, 144, 4096\n        image_features = nn.op.reshape(\n            image_features, ([num_images, h_crop, w_crop, H // 2, H // 2, -1])\n        )\n\n        new_s3 = tir.Var(\"new_s3\", \"int64\")\n        new_s4 = tir.Var(\"new_s4\", \"int64\")\n\n        image_features = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                image_features._expr,  # pylint: disable=protected-access\n                relax.TensorStructInfo(\n                    [\n                        image_features.shape[0],\n                        new_s3,\n                        image_features.shape[2],\n                        new_s4,\n                        image_features.shape[4],\n                        image_features.shape[5],\n                    ],\n                    image_features.dtype,\n                ),\n            ),\n            \"image_features_2\",\n        )\n\n        image_features = nn.op.permute_dims(image_features, axes=([0, 1, 3, 2, 4, 5]))\n        image_features_hd = nn.op.reshape(\n            image_features, ([num_images, h_crop * H // 2, w_crop * H // 2, 4 * C])\n        )\n\n        return image_features_hd\n\n    def add_image_newline(self, image_features_hd):\n        \"\"\"\n        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)\n        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)\n        \"\"\"\n        num_images, h, w, hid_dim = image_features_hd.shape  # pylint: disable=unused-variable\n        temp_sub_GN = self.dyn_repeat_4d_tensor(\n            self.sub_GN, T.int64(1), T.int64(h), T.int64(1), T.int64(1)\n        )\n        image_features_hd_newline = 
self.dyn_concate_dim_2(image_features_hd, temp_sub_GN)\n        image_features_hd_newline = nn.op.reshape(\n            image_features_hd_newline, ([num_images, -1, hid_dim])\n        )\n        return image_features_hd_newline\n\n    # pylint: disable=too-many-locals,unused-argument\n    def forward(self, pixel_values: Tensor, h_crop, w_crop) -> Tensor:\n        img_features = self.get_img_features(pixel_values)\n        img_features = nn.op.split(img_features, indices_or_sections=[1], axis=0)\n\n        global_image_features = img_features[0]\n        global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)\n        global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)\n\n        sub_image_features = img_features[1]\n        sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)\n        sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)\n\n        global_image_features_hd = nn.op.squeeze(global_image_features_hd_newline, 0)\n\n        combined_image = self.dyn_concate_dim_1(sub_image_features_hd_newline, self.glb_GN)\n        combined_image = self.dyn_concate_dim_1(combined_image, global_image_features_hd_newline)\n        combined_image = nn.op.squeeze(combined_image, 0)\n\n        new_s7 = tir.Var(\"new_s7\", \"int64\")\n\n        combined_image = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                combined_image._expr,  # pylint: disable=protected-access\n                relax.TensorStructInfo([new_s7, combined_image.shape[1]], combined_image.dtype),\n            ),\n            \"combined_image\",\n        )\n        output_image = self.img_projection(combined_image)\n        return output_image\n"
  },
  {
    "path": "python/mlc_llm/model/phi3v/phi3v_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Phi-3 Vision parameters map from other formats, for example\nHuggingFace PyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .phi3v_model import Phi3VConfig, Phi3VForCausalLM\n\n\n# pylint: disable=too-many-statements\ndef huggingface(model_config: Phi3VConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of Phi-3 Vision HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Phi3VConfig\n        The configuration of the Phi-3 Vision model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Phi3VForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params = model.export_tvm(  # pylint: disable=W0632:unbalanced-tuple-unpacking\n        spec=model.get_default_spec()\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    def _add(mlc_name, hf_name=None):\n        if None is hf_name:\n            hf_name = mlc_name\n\n        mapping.add_mapping(\n            mlc_name,\n            [hf_name],\n            functools.partial(\n                lambda x, dtype: x.astype(dtype),\n                dtype=named_parameters[mlc_name].dtype,\n            ),\n        )\n\n    def _add_vision(name):\n        _add(name, \"model.\" + name)\n\n    _add(\"model.embd.weight\", \"model.embed_tokens.weight\")\n\n    # pylint: disable=line-too-long\n    prefix = \"model.h\"\n    hf_prefix = \"model.layers\"\n    for i in range(model_config.num_hidden_layers):\n        _add(f\"{prefix}.{i}.ln.weight\", f\"{hf_prefix}.{i}.input_layernorm.weight\")\n  
      _add(\n            f\"{prefix}.{i}.mlp.down_proj.weight\",\n            f\"{hf_prefix}.{i}.mlp.down_proj.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.mlp.gate_up_proj.weight\",\n            f\"{hf_prefix}.{i}.mlp.gate_up_proj.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.post_attention_layernorm.weight\",\n            f\"{hf_prefix}.{i}.post_attention_layernorm.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.mixer.out_proj.weight\",\n            f\"{hf_prefix}.{i}.self_attn.o_proj.weight\",\n        )\n        _add(\n            f\"{prefix}.{i}.mixer.qkv_proj.weight\",\n            f\"{hf_prefix}.{i}.self_attn.qkv_proj.weight\",\n        )\n\n    prefix = \"vision_embed_tokens.img_processor.vision_model.encoder.layers\"\n    for i in range(model_config.vision_config.num_hidden_layers):\n        _add_vision(f\"{prefix}.{i}.layer_norm1.bias\")\n        _add_vision(f\"{prefix}.{i}.layer_norm1.weight\")\n        _add_vision(f\"{prefix}.{i}.layer_norm2.bias\")\n        _add_vision(f\"{prefix}.{i}.layer_norm2.weight\")\n        _add_vision(f\"{prefix}.{i}.mlp.fc1.bias\")\n        _add_vision(f\"{prefix}.{i}.mlp.fc1.weight\")\n        _add_vision(f\"{prefix}.{i}.mlp.fc2.bias\")\n        _add_vision(f\"{prefix}.{i}.mlp.fc2.weight\")\n        _add_vision(f\"{prefix}.{i}.self_attn.k_proj.bias\")\n        _add_vision(f\"{prefix}.{i}.self_attn.k_proj.weight\")\n        _add_vision(f\"{prefix}.{i}.self_attn.out_proj.bias\")\n        _add_vision(f\"{prefix}.{i}.self_attn.out_proj.weight\")\n        _add_vision(f\"{prefix}.{i}.self_attn.q_proj.bias\")\n        _add_vision(f\"{prefix}.{i}.self_attn.q_proj.weight\")\n        _add_vision(f\"{prefix}.{i}.self_attn.v_proj.bias\")\n        _add_vision(f\"{prefix}.{i}.self_attn.v_proj.weight\")\n\n    _add_vision(\"vision_embed_tokens.sub_GN\")\n    _add_vision(\"vision_embed_tokens.glb_GN\")\n    
_add_vision(\"vision_embed_tokens.img_processor.vision_model.embeddings.class_embedding\")\n    _add_vision(\"vision_embed_tokens.img_processor.vision_model.embeddings.patch_embedding.weight\")\n    _add_vision(\n        \"vision_embed_tokens.img_processor.vision_model.embeddings.position_embedding.weight\"\n    )\n    _add_vision(\"vision_embed_tokens.img_processor.vision_model.post_layernorm.bias\")\n    _add_vision(\"vision_embed_tokens.img_processor.vision_model.post_layernorm.weight\")\n    _add_vision(\"vision_embed_tokens.img_processor.vision_model.pre_layrnorm.bias\")\n    _add_vision(\"vision_embed_tokens.img_processor.vision_model.pre_layrnorm.weight\")\n\n    prefix = \"vision_embed_tokens.img_projection\"\n    _add(f\"{prefix}.linear_1.bias\", f\"model.{prefix}.0.bias\")\n    _add(f\"{prefix}.linear_1.weight\", f\"model.{prefix}.0.weight\")\n    _add(f\"{prefix}.linear_2.bias\", f\"model.{prefix}.2.bias\")\n    _add(f\"{prefix}.linear_2.weight\", f\"model.{prefix}.2.weight\")\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    mapping.add_unused(\"model.embed_tokens.weight\")\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/phi3v/phi3v_model.py",
    "content": "\"\"\"\nImplementation of the Phi-3 Vision architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import relax, target, te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.phi3 import Phi3Model\nfrom mlc_llm.model.vision import CLIPVisionConfig, ImageProcessor\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nfrom .phi3v_image import Phi3ImageEmbedding\n\nlogger = logging.getLogger(__name__)\n\nCLIPVISION_DEFAULT_CONFIG = {\n    \"hidden_size\": 1024,\n    \"image_size\": 336,\n    \"intermediate_size\": 4096,\n    \"num_attention_heads\": 16,\n    \"num_hidden_layers\": 24,\n    \"patch_size\": 14,\n    \"projection_dim\": 768,\n    \"layer_norm_eps\": 1e-05,\n    \"vocab_size\": None,\n}\n\n\n@dataclasses.dataclass\nclass Phi3VConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Phi-3 Vision model.\"\"\"\n\n    model_type: str\n    hidden_size: int\n    vocab_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    intermediate_size: int\n    rms_norm_eps: float\n    num_key_value_heads: int\n    max_position_embeddings: int\n    vision_config: CLIPVisionConfig = None\n    img_processor: Optional[Dict[str, Any]] = None\n    position_embedding_base: int = 0\n    rope_scaling: Optional[Dict[str, Any]] = None\n    original_max_position_embeddings: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    head_dim: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    # pylint: disable=too-many-branches, consider-using-min-builtin\n    def __post_init__(self):\n        vision_config_dict: Dict[str, Any]\n        if 
isinstance(self.vision_config, CLIPVisionConfig):\n            vision_config_dict = dataclasses.asdict(self.vision_config)\n        else:\n            vision_config_dict = dict(CLIPVISION_DEFAULT_CONFIG)\n\n        for k, v in vision_config_dict.pop(\"kwargs\", {}).items():\n            vision_config_dict[k] = v\n\n        self.vision_config = CLIPVisionConfig.from_dict(vision_config_dict)\n\n        if self.position_embedding_base == 0:\n            if \"rope_theta\" in self.kwargs:\n                self.position_embedding_base = self.kwargs.pop(\"rope_theta\")\n            else:\n                self.position_embedding_base = 10000\n        if self.rope_scaling is not None:\n            if \"type\" not in self.rope_scaling:\n                self.rope_scaling = None\n            else:\n                if self.rope_scaling[\"type\"] == \"su\":\n                    self.rope_scaling[\"type\"] = \"longrope\"\n\n                assert (\n                    self.rope_scaling[\"type\"] == \"longrope\"\n                ), f\"Unsupported RoPE scaling type {self.rope_scaling['type']} for Phi3\"\n                self.rope_scaling[\"rope_type\"] = self.rope_scaling[\"type\"]\n                (\n                    self.rope_scaling[\"max_position_embeddings\"],\n                    self.rope_scaling[\"original_max_position_embeddings\"],\n                ) = (\n                    self.max_position_embeddings,\n                    self.original_max_position_embeddings,\n                )\n\n        if self.context_window_size == 0:\n            self.context_window_size = self.max_position_embeddings\n\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n         
   logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n        if self.num_key_value_heads == 0 or self.num_key_value_heads is None:\n            self.num_key_value_heads = self.num_attention_heads\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        assert self.num_attention_heads % self.num_key_value_heads == 0\n\n\n# pylint: disable=invalid-name,missing-docstring, too-many-branches\n\n\n# mypy: disable-error-code=\"arg-type,annotation-unchecked\"\nclass Phi3VForCausalLM(nn.Module):\n    # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Phi3VConfig) -> None:\n        super().__init__()\n\n        self.config = config\n        self.model = Phi3Model(config)\n        self.lm_head = nn.Linear(config.hidden_size, \"vocab_size\", bias=False)\n        self.vision_embed_tokens = Phi3ImageEmbedding(config)\n        self.image_processor = ImageProcessor()\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.rope_scaling = config.rope_scaling\n        self.rope_theta = config.position_embedding_base\n        self.rope_ext_factors = (\n            (config.rope_scaling[\"long_factor\"] + config.rope_scaling[\"short_factor\"])\n            if config.rope_scaling is not None\n            else None\n        )\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype 
= \"float32\"\n        self.image_dtype = (\n            \"uint32\"\n            if target.Target.current() and target.Target.current().kind.name == \"webgpu\"\n            else \"uint8\"\n        )\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        lm_logits = self.lm_head(hidden_states)\n        if lm_logits.dtype != \"float32\":\n            lm_logits = lm_logits.astype(\"float32\")\n        return lm_logits\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n  
  ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        embeds = self.model.embd(input_ids)\n        return embeds\n\n    # pylint: disable=protected-access\n    def image_preprocess(\n        self, pixel_values: Tensor, resized_height, resized_width, num_crops=16\n    ) -> Tensor:\n        pixel_values = op.permute_dims(pixel_values, axes=(0, 3, 1, 2))  # NHWC -> NCHW\n        pixel_values = self.image_processor.resize(\n            pixel_values, params={\"height\": resized_height, \"width\": resized_width}\n        )\n        pixel_values = self.image_processor.pad(pixel_values, dtype=self.image_dtype)\n        pixel_values = self.image_processor.rescale(pixel_values)\n        pixel_values = self.image_processor.normalize(pixel_values)\n        global_image = self.image_processor.resize(\n            pixel_values, params={\"height\": 336, \"width\": 336}\n        )\n        global_image = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                global_image._expr,\n                relax.TensorStructInfo(\n                    [global_image.shape[0], global_image.shape[1], 336, 336],\n                    global_image.dtype,\n                ),\n            ),\n            
\"global_image\",\n        )\n\n        n, c, h, w = pixel_values.shape  # pylint: disable=unused-variable\n        assert isinstance(h, tir.Mul) and isinstance(h.b, tir.IntImm) and h.b.value == 336\n        pixel_values = op.reshape(pixel_values, shape=(1, 3, h.a, 336, w // 336, 336))\n        pixel_values = op.permute_dims(pixel_values, axes=(0, 2, 4, 1, 3, 5))\n        pixel_values = op.reshape(pixel_values, shape=(-1, 3, 336, 336))\n        combined_image = op.concat([global_image, pixel_values], dim=0)\n\n        # pad to max num crops tensor\n        b, c, h, w = combined_image.shape\n        zeros = op.zeros((num_crops + 1 - b, c, h, w))\n        combined_image = op.concat([combined_image, zeros], dim=0)\n\n        combined_image = op.wrap_nested(\n            relax.BlockBuilder()\n            .current()\n            .match_cast(\n                combined_image._expr,\n                relax.TensorStructInfo([num_crops + 1, c, h, w], combined_image.dtype),\n            ),\n            \"combined_image\",\n        )\n\n        return combined_image\n\n    def image_embed(  # pylint: disable=too-many-arguments\n        self,\n        pixel_values: Tensor,\n        resized_height,\n        resized_width,\n        crop_height,\n        crop_width,\n    ) -> Tensor:\n        n, h, w, c = pixel_values.shape  # pylint: disable=unused-variable\n        pixel_values = self.image_preprocess(pixel_values, resized_height, resized_width)\n        pixel_values = pixel_values.astype(self.dtype)\n        return self.vision_embed_tokens(pixel_values, crop_height, crop_width)\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            
max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scaling=self.rope_scaling,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            rope_ext_factors=self.rope_ext_factors,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"image_embed\": {\n                \"pixel_values\": nn.spec.Tensor(\n                    [1, \"image_height\", \"image_width\", 3], self.image_dtype\n                ),\n                \"resized_height\": nn.spec.Int(),\n                \"resized_width\": nn.spec.Int(),\n                \"crop_height\": nn.spec.Int(),\n                \"crop_width\": nn.spec.Int(),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n   
         },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n       
 return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/qwen/qwen_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's QWen parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .qwen_model import QWenLMHeadModel\n\nhuggingface = make_standard_hf_loader(\n    model_cls=QWenLMHeadModel,\n    layer_prefix=\"transformer.h\",\n    qkv_names=(),\n    include_qkv=False,\n    gate_up_names=(\"w1\", \"w2\"),\n)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen/qwen_model.py",
    "content": "\"\"\"\nImplementation for QWEN architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass QWenConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the QWen model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    layer_norm_epsilon: float\n    scale_attn_weights: bool\n    kv_channels: int\n    rotary_emb_base: int\n    intermediate_size: int\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    head_dim: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass QWenAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: QWenConfig):\n        self.hidden_size = config.hidden_size\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // config.tensor_parallel_shards\n        
self.head_dim = config.head_dim\n\n        self.c_attn = nn.Linear(config.hidden_size, 3 * self.num_heads * self.head_dim, bias=True)\n\n        self.c_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)\n\n    def forward(  # pylint: disable=too-many-locals\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n    ):\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape\n\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, 3 * h, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h * d),\n        )\n        return self.c_proj(output)\n\n\nclass QWenMLP(nn.Module):\n    def __init__(self, config: QWenConfig):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=self.intermediate_size,\n            bias=False,\n        )\n        self.c_proj = nn.Linear(self.intermediate_size // 2, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.c_proj(x1 * op.silu(x2))\n\n\nclass QWenBlock(nn.Module):\n    def __init__(self, config: QWenConfig):\n        rms_norm_eps = config.layer_norm_epsilon\n        self.attn = QWenAttention(config)\n        self.mlp = QWenMLP(config)\n        self.ln_1 = nn.RMSNorm(config.hidden_size, -1, 
rms_norm_eps, bias=False)\n        self.ln_2 = nn.RMSNorm(config.hidden_size, -1, rms_norm_eps, bias=False)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.attn.num_heads * hd\n            k = self.attn.num_heads * hd\n            v = self.attn.num_heads * hd\n            i = self.mlp.intermediate_size // 2\n            _set(\n                self.attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(\n                self.attn.c_attn.bias,\n                tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.attn.c_proj.weight, tp.ShardSingleDim(\"_shard_attn_c_proj\", dim=1))\n            _set(\n                self.mlp.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_gate_up_proj\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.c_proj.weight, tp.ShardSingleDim(\"_shard_mlp_c_proj\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.attn(self.ln_1(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.ln_2(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass QWenModel(nn.Module):\n    def __init__(self, config: QWenConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)\n   
     self.h = nn.ModuleList([QWenBlock(config) for _ in range(config.num_hidden_layers)])\n        self.ln_f = nn.RMSNorm(config.hidden_size, -1, config.layer_norm_epsilon, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.h):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.ln_f(hidden_states)\n        return hidden_states\n\n\nclass QWenLMHeadModel(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: QWenConfig):\n        self.transformer = QWenModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False, dtype=\"float32\")\n        self.hidden_size = config.hidden_size\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.num_attention_heads = config.num_attention_heads\n        self.head_dim = config.head_dim\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.rotary_emb_base = config.rotary_emb_base\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        inputs: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n        hidden_states = self.transformer(inputs, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = 
op.ccl_broadcast_from_worker0(input_ids)\n        return self.transformer.wte(input_ids)\n\n    def prefill(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.transformer(inputs, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(\n            _index,\n            name_hint=\"index\",\n            args=[hidden_states],\n        )\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.transformer(inputs, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(self, inputs: Tensor, logit_positions: Tensor, paged_kv_cache: PagedKVCache):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(inputs, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(inputs, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(inputs, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        
page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rotary_emb_base,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"inputs\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"inputs\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"inputs\": nn.spec.Tensor([1, \"seq_len\", 
self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"inputs\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"inputs\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen2/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/qwen2/qwen2_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's QWen2 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .qwen2_model import QWen2LMHeadModel\n\nhuggingface = make_standard_hf_loader(\n    model_cls=QWen2LMHeadModel,\n    qkv_target_name=\"c_attn\",\n    add_qkv_bias=True,\n    add_unused=[\"rotary_emb.inv_freq\"],\n)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen2/qwen2_model.py",
    "content": "\"\"\"\nImplementation for QWEN2 architecture.\n\"\"\"\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass QWen2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the QWen2 model.\"\"\"\n\n    hidden_act: str\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    num_key_value_heads: int\n    rms_norm_eps: float\n    rope_theta: int\n    vocab_size: int\n    tie_word_embeddings: bool = False\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    dtype: str = \"float32\"\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass QWen2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: QWen2Config):\n        self.head_dim = config.head_dim\n        if config.num_key_value_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_key_value_heads} key-value attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_attention_heads = config.num_attention_heads // 
config.tensor_parallel_shards\n        self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.rope_theta = config.rope_theta\n\n        self.c_attn = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(2 * self.num_key_value_heads + self.num_attention_heads) * self.head_dim,\n            bias=True,\n        )\n        self.o_proj = nn.Linear(\n            self.num_attention_heads * self.head_dim, config.hidden_size, bias=False\n        )\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_attention_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_attention_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass Qwen2Embedding(nn.Embedding):\n    \"\"\"The embedding module specialized for Qwen2 so that\n    it can be shared with the final lm_head.\n    \"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass QWen2MLP(nn.Module):\n    def __init__(self, config: QWen2Config):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                
f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(self.act_fn(x1) * x2)\n\n\nclass QWen2DecoderLayer(nn.Module):\n    def __init__(self, config: QWen2Config):\n        self.self_attn = QWen2Attention(config)\n        self.mlp = QWen2MLP(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_attention_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(\n                self.self_attn.c_attn.bias,\n                tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], 
dim=0),\n            )\n            _set(self.mlp.down_proj.weight, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.input_layernorm(hidden_states)\n        out = self.self_attn(out, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass QWen2Model(nn.Module):\n    def __init__(self, config: QWen2Config):\n        self.embed_tokens = Qwen2Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [QWen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass QWen2LMHeadModel(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: QWen2Config):\n        self.model = QWen2Model(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.dtype = config.dtype\n        self.hidden_size = 
config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.intermediate_size = config.intermediate_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.rms_norm_eps = config.rms_norm_eps\n        self.rope_theta = config.rope_theta\n        self.vocab_size = config.vocab_size\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:-1,:]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        if 
self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            
max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": 
{\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen2_5_vl/__init__.py",
    "content": "\"\"\"Qwen2.5-VL architecture entry.\"\"\"\n\nfrom .qwen2_5_vl_model import Qwen25VLConfig, Qwen25VLLMHeadModel  # noqa: F401\n"
  },
  {
    "path": "python/mlc_llm/model/qwen2_5_vl/qwen2_5_vl_model.py",
    "content": "\"\"\"Partial Qwen2.5-VL implementation focused on decoder-side MRoPE support.\n\nThis file intentionally does not provide complete end-to-end Qwen2.5-VL support yet.\nThe current scope is:\n\n- decoder-side text model structure,\n- multimodal rotary embedding generation/application,\n- position-id layout compatibility for the MRoPE path.\n\nMissing pieces for full Qwen2.5-VL support include:\n\n- vision tower / multimodal projector and image-video embedding path,\n- model registry / loader / quantization / preset integration,\n- end-to-end multimodal preprocessing and compile/chat wiring.\n\"\"\"\n\n# pylint: disable=missing-function-docstring,missing-class-docstring\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.op.mrope import (\n    MultimodalRotaryEmbedding,\n    VisionPositionMetadata,\n    apply_multimodal_rotary_pos_emb,\n)\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\n@dataclasses.dataclass\nclass Qwen25VLVisionTokenConfig:\n    \"\"\"Vision token IDs used by Qwen2.5-VL.\"\"\"\n\n    image_token_id: int = 151655\n    video_token_id: int = 151656\n    vision_start_token_id: int = 151652\n    vision_end_token_id: int = 151653\n\n\n@dataclasses.dataclass\nclass Qwen25VLVisionGridConfig:\n    \"\"\"Vision grid configuration for multimodal position IDs.\"\"\"\n\n    spatial_merge_size: int = 2\n    temporal_patch_size: int = 2\n    tokens_per_second: float = 4.0\n\n\n@dataclasses.dataclass(frozen=True)\nclass Qwen25VLAttentionState:\n    
\"\"\"Derived attention dimensions used across Qwen2.5-VL attention code paths.\"\"\"\n\n    head_dim: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    mrope_section: Tuple[int, int, int]\n    softmax_scale: float\n\n\n@dataclasses.dataclass\nclass Qwen25VLConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for the Qwen2.5-VL model.\"\"\"\n\n    hidden_act: str\n    hidden_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    num_key_value_heads: int\n    rms_norm_eps: float\n    rope_theta: float\n    vocab_size: int\n    tie_word_embeddings: bool = False\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    dtype: str = \"float32\"\n    max_batch_size: int = 1\n    rope_parameters: Optional[Dict[str, Any]] = None\n    mrope_section: Optional[Tuple[int, int, int]] = None\n    vision_tokens: Qwen25VLVisionTokenConfig = dataclasses.field(\n        default_factory=Qwen25VLVisionTokenConfig\n    )\n    vision_grid: Qwen25VLVisionGridConfig = dataclasses.field(\n        default_factory=Qwen25VLVisionGridConfig\n    )\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):  # pylint: disable=too-many-branches\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        if self.prefill_chunk_size == 0:\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n        rope_scaling = 
self.kwargs.pop(\"rope_scaling\", None)\n        if self.rope_parameters is None:\n            self.rope_parameters = rope_scaling or {}\n        if self.mrope_section is None:\n            section = self.rope_parameters.get(\"mrope_section\")\n            if section is None and rope_scaling is not None:\n                section = rope_scaling.get(\"mrope_section\")\n            if section is None:\n                raise ValueError(\"`mrope_section` must be provided for Qwen2.5-VL.\")\n            self.mrope_section = tuple(int(i) for i in section)\n        if len(self.mrope_section) != 3:\n            raise ValueError(f\"mrope_section must contain 3 integers, got {self.mrope_section}.\")\n\n        for key in [\n            \"image_token_id\",\n            \"video_token_id\",\n            \"vision_start_token_id\",\n            \"vision_end_token_id\",\n        ]:\n            if key in self.kwargs:\n                setattr(self.vision_tokens, key, int(self.kwargs.pop(key)))\n        for key in [\"spatial_merge_size\", \"temporal_patch_size\"]:\n            if key in self.kwargs:\n                setattr(self.vision_grid, key, int(self.kwargs.pop(key)))\n        if \"tokens_per_second\" in self.kwargs:\n            self.vision_grid.tokens_per_second = float(self.kwargs.pop(\"tokens_per_second\"))\n\n        vision_cfg = self.kwargs.pop(\"vision_config\", {})\n        if vision_cfg:\n            if not isinstance(vision_cfg, dict):\n                raise ValueError(f\"vision_config must be a dict, got {type(vision_cfg)}.\")\n            self.vision_grid.spatial_merge_size = int(\n                vision_cfg.get(\"spatial_merge_size\", self.vision_grid.spatial_merge_size)\n            )\n            self.vision_grid.temporal_patch_size = int(\n                vision_cfg.get(\"temporal_patch_size\", self.vision_grid.temporal_patch_size)\n            )\n            self.vision_grid.tokens_per_second = float(\n                vision_cfg.get(\"tokens_per_second\", 
self.vision_grid.tokens_per_second)\n            )\n\n    @property\n    def image_token_id(self) -> int:\n        return self.vision_tokens.image_token_id\n\n    @property\n    def video_token_id(self) -> int:\n        return self.vision_tokens.video_token_id\n\n    @property\n    def vision_start_token_id(self) -> int:\n        return self.vision_tokens.vision_start_token_id\n\n    @property\n    def vision_end_token_id(self) -> int:\n        return self.vision_tokens.vision_end_token_id\n\n    @property\n    def spatial_merge_size(self) -> int:\n        return self.vision_grid.spatial_merge_size\n\n    @property\n    def temporal_patch_size(self) -> int:\n        return self.vision_grid.temporal_patch_size\n\n    @property\n    def tokens_per_second(self) -> float:\n        return self.vision_grid.tokens_per_second\n\n    @property\n    def vision_metadata(self) -> VisionPositionMetadata:\n        return VisionPositionMetadata(\n            vision_start_token_id=self.vision_tokens.vision_start_token_id,\n            image_token_id=self.vision_tokens.image_token_id,\n            video_token_id=self.vision_tokens.video_token_id,\n            spatial_merge_size=self.vision_grid.spatial_merge_size,\n            tokens_per_second=self.vision_grid.tokens_per_second,\n        )\n\n\nclass Qwen25VLEmbedding(nn.Embedding):\n    \"\"\"Embedding module shared with LM head.\"\"\"\n\n    def lm_head_forward(self, x: Tensor):\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, out_dtype=\"float32\")\n\n\nclass Qwen25VLAttention(nn.Module):\n    def __init__(self, config: Qwen25VLConfig):\n        if config.num_key_value_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_key_value_heads} key-value heads \"\n                f\"evenly to {config.tensor_parallel_shards} shards.\"\n            )\n        head_dim = config.head_dim\n        num_attention_heads = 
config.num_attention_heads // config.tensor_parallel_shards\n        num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        mrope_section: Tuple[int, int, int] = (\n            config.mrope_section if config.mrope_section is not None else (0, 0, 0)\n        )\n        self.state = Qwen25VLAttentionState(\n            head_dim=head_dim,\n            num_attention_heads=num_attention_heads,\n            num_key_value_heads=num_key_value_heads,\n            mrope_section=mrope_section,\n            softmax_scale=head_dim**-0.5,\n        )\n\n        out_features = (num_attention_heads + 2 * num_key_value_heads) * head_dim\n        self.c_attn = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=out_features,\n            bias=True,\n        )\n        self.o_proj = nn.Linear(\n            num_attention_heads * head_dim,\n            config.hidden_size,\n            bias=False,\n        )\n\n    @property\n    def head_dim(self) -> int:\n        return self.state.head_dim\n\n    @property\n    def num_attention_heads(self) -> int:\n        return self.state.num_attention_heads\n\n    @property\n    def num_key_value_heads(self) -> int:\n        return self.state.num_key_value_heads\n\n    def forward(  # pylint: disable=too-many-locals\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n        position_embeddings: Tuple[Tensor, Tensor],\n    ):\n        d, h_q, h_kv = (\n            self.state.head_dim,\n            self.state.num_attention_heads,\n            self.state.num_key_value_heads,\n        )\n        b, s, _ = hidden_states.shape\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        q, k, v = op.split(qkv, [h_q, h_q + h_kv], axis=2)\n        cos, sin = position_embeddings\n        q, k = apply_multimodal_rotary_pos_emb(q, k, cos, sin, self.state.mrope_section)\n        output, 
_ = paged_kv_cache.self_attention(layer_id, q, k, v, self.state.softmax_scale)\n        output = op.reshape(output, (b, s, h_q * d))\n        return self.o_proj(output)\n\n\nclass Qwen25VLMLP(nn.Module):\n    def __init__(self, config: Qwen25VLConfig):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} shards.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(self.act_fn(x1) * x2)\n\n\nclass Qwen25VLDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen25VLConfig):\n        self.self_attn = Qwen25VLAttention(config)\n        self.mlp = Qwen25VLMLP(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self._set_tp(config)\n\n    def _set_tp(self, config: Qwen25VLConfig):\n        def _set(layer, hint):\n            layer.attrs[\"shard_strategy\"] = hint\n\n        hd = config.head_dim\n        q = self.self_attn.num_attention_heads * hd\n        k = self.self_attn.num_key_value_heads * hd\n        v = self.self_attn.num_key_value_heads * hd\n        i = self.mlp.intermediate_size\n        _set(\n            self.self_attn.c_attn.weight,\n          
  tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n        )\n        _set(\n            self.self_attn.c_attn.bias,\n            tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n        )\n        _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n        _set(\n            self.mlp.gate_up_proj.weight,\n            tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n        )\n        _set(self.mlp.down_proj.weight, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n        paged_kv_cache: PagedKVCache,\n        layer_id: int,\n        position_embeddings: Tuple[Tensor, Tensor],\n    ):\n        out = self.input_layernorm(hidden_states)\n        out = self.self_attn(out, paged_kv_cache, layer_id, position_embeddings)\n        hidden_states = self._apply_residual(out, hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)\n        hidden_states = self._apply_residual(out, hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out: Tensor, residual: Tensor) -> Tensor:\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Qwen25VLModel(nn.Module):\n    def __init__(self, config: Qwen25VLConfig):\n        self.embed_tokens = Qwen25VLEmbedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Qwen25VLDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        attention_scaling = config.rope_parameters.get(\"attention_scaling\", 1.0)\n        self.rotary_emb = MultimodalRotaryEmbedding(\n            head_dim=config.head_dim,\n            theta=config.rope_theta,\n            mrope_section=config.mrope_section,\n            
attention_scaling=attention_scaling,\n        )\n\n    def forward(\n        self,\n        inputs: Tensor,\n        position_ids: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        hidden_states = inputs\n        cos, sin = self.rotary_emb(hidden_states, position_ids)\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id, (cos, sin))\n        return self.norm(hidden_states)\n\n\nclass Qwen25VLLMHeadModel(nn.Module):\n    def __init__(self, config: Qwen25VLConfig):\n        self.config = config\n        self.model = Qwen25VLModel(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not self.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.dtype = config.dtype\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def _apply_lm_head(self, hidden_states: Tensor):\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def _set_mrope_delta(self, paged_kv_cache: PagedKVCache, deltas: Tensor):\n        setattr(paged_kv_cache, \"_mrope_delta\", deltas)\n        return deltas\n\n    def _get_mrope_delta(self, paged_kv_cache: PagedKVCache, batch: int) -> Tensor:\n        delta = getattr(paged_kv_cache, \"_mrope_delta\", None)\n        if delta is None:\n            delta = op.zeros((batch, 1), \"int32\")\n            setattr(paged_kv_cache, \"_mrope_delta\", delta)\n        return delta\n\n    def _build_decode_position_ids(\n        self,\n        seq_len: int,\n        paged_kv_cache: PagedKVCache,\n        batch: int,\n    ) -> Tensor:\n       
 base = paged_kv_cache.get_query_positions(seq_len)\n        base = op.reshape(base, (1, seq_len))\n        base = op.broadcast_to(base, (batch, seq_len))\n        delta = self._get_mrope_delta(paged_kv_cache, batch)\n        base = base + delta\n        base = op.unsqueeze(base, dim=0)\n        return op.broadcast_to(base, (3, batch, seq_len))\n\n    def prefill(\n        self,\n        input_embed: Tensor,\n        position_ids: Tensor,\n        mrope_deltas: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        op_ext.configure()\n        self._set_mrope_delta(paged_kv_cache, mrope_deltas)\n        hidden_states = self.model(input_embed, position_ids, paged_kv_cache)\n\n        def _index(x: te.Tensor):\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self._apply_lm_head(hidden_states)\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n        b, s, _ = input_embed.shape\n        position_ids = self._build_decode_position_ids(s, paged_kv_cache, b)\n        hidden_states = self.model(input_embed, position_ids, paged_kv_cache)\n        logits = self._apply_lm_head(hidden_states)\n        return logits, paged_kv_cache\n\n    def batch_prefill(  # pylint: disable=too-many-arguments\n        self,\n        input_embeds: Tensor,\n        position_ids: Tensor,\n        mrope_deltas: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.config.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(\n            input_embeds, position_ids, mrope_deltas, logit_positions, paged_kv_cache\n        )\n        return logits, paged_kv_cache\n\n    def 
batch_forward(  # pylint: disable=too-many-arguments\n        self,\n        input_embeds: Tensor,\n        position_ids: Tensor,\n        mrope_deltas: Tensor,\n        logit_positions: Optional[Tensor],\n        paged_kv_cache: PagedKVCache,\n    ):\n        op_ext.configure()\n        self._set_mrope_delta(paged_kv_cache, mrope_deltas)\n        hidden_states = self.model(input_embeds, position_ids, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        return self._apply_lm_head(hidden_states)\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n        b, s, _ = input_embeds.shape\n        position_ids = self._build_decode_position_ids(s, paged_kv_cache, b)\n        hidden_states = self.model(input_embeds, position_ids, paged_kv_cache)\n        logits = self._apply_lm_head(hidden_states)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        return self.batch_decode(input_embeds, paged_kv_cache)\n\n    def embed(self, input_ids: Tensor):\n        if self.config.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        cfg = self.config\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            
num_hidden_layers=cfg.num_hidden_layers,\n            num_attention_heads=cfg.num_attention_heads // cfg.tensor_parallel_shards,\n            num_key_value_heads=cfg.num_key_value_heads // cfg.tensor_parallel_shards,\n            qk_head_dim=cfg.head_dim,\n            v_head_dim=cfg.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scaling=cfg.rope_parameters,\n            rope_scale=1,\n            rope_theta=int(cfg.rope_theta),\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        cfg = self.config\n        seq_len = \"seq_len\"\n        hidden = cfg.hidden_size\n        dtype = self.dtype\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([seq_len], \"int32\"),\n                \"$\": {\"param_mode\": \"packed\", \"effect_mode\": \"none\"},\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, seq_len, hidden], dtype),\n                \"position_ids\": nn.spec.Tensor([3, 1, seq_len], \"int32\"),\n                \"mrope_deltas\": nn.spec.Tensor([1, 1], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\"param_mode\": \"packed\", \"effect_mode\": \"none\"},\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, hidden], dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\"param_mode\": \"packed\", \"effect_mode\": \"none\"},\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, seq_len, hidden], dtype),\n                \"position_ids\": nn.spec.Tensor([3, 1, seq_len], \"int32\"),\n                \"mrope_deltas\": nn.spec.Tensor([1, 1], \"int32\"),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n 
               \"$\": {\"param_mode\": \"packed\", \"effect_mode\": \"none\"},\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, hidden], dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\"param_mode\": \"packed\", \"effect_mode\": \"none\"},\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, hidden], dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\"param_mode\": \"packed\", \"effect_mode\": \"none\"},\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\"param_mode\": \"none\", \"effect_mode\": \"none\"},\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen2_moe/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/qwen2_moe/qwen2_moe_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's QWen2 parameter maps from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .qwen2_moe_model import Qwen2MoeConfig, Qwen2MoeForCausalLM\n\n\ndef huggingface(model_config: Qwen2MoeConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Qwen2MoeConfig\n        The configuration of the Qwen2Moe model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Qwen2MoeForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # map attention weight\n        attn = f\"model.layers.{i}.self_attn\"\n        for weight_type in [\"weight\", \"bias\"]:\n            mlc_name = f\"{attn}.c_attn.{weight_type}\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.{weight_type}\",\n                    f\"{attn}.k_proj.{weight_type}\",\n                    f\"{attn}.v_proj.{weight_type}\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    
dtype=mlc_param.dtype,\n                ),\n            )\n        # map mlp shared expert weight\n        mlp = f\"model.layers.{i}.mlp\"\n        shared_expert = f\"{mlp}.shared_expert\"\n        mlc_name = f\"{shared_expert}.gate_up_proj.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"{shared_expert}.gate_proj.weight\",\n                f\"{shared_expert}.up_proj.weight\",\n            ],\n            functools.partial(\n                lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n        # map mlp moe gate and up weight\n        mlc_name = f\"{mlp}.moe_gate_up_proj.weight\"\n        mlc_param = named_parameters[mlc_name]\n\n        def combine_expert_gate_up(*hf_params, dtype):\n            # hf_params alternates (gate, up) per expert; concatenate each pair along the\n            # output dim, then stack to (num_experts, 2 * moe_intermediate_size, hidden_size).\n            stack = []\n            for j in range(0, len(hf_params), 2):\n                stack.append(np.concatenate([hf_params[j], hf_params[j + 1]], axis=0))\n            return np.stack(stack, axis=0).astype(dtype)\n\n        mapping.add_mapping(\n            mlc_name,\n            functools.reduce(\n                lambda a, b: a + b,\n                [\n                    [\n                        f\"{mlp}.experts.{expert_id}.gate_proj.weight\",\n                        f\"{mlp}.experts.{expert_id}.up_proj.weight\",\n                    ]\n                    for expert_id in range(model_config.num_experts)\n                ],\n            ),\n            functools.partial(\n                combine_expert_gate_up,\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        # map mlp moe down weight (stacked to (num_experts, hidden_size, moe_intermediate_size))\n        mlc_name = f\"{mlp}.moe_down_proj.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"{mlp}.experts.{expert_id}.down_proj.weight\"\n                for expert_id in range(model_config.num_experts)\n            ],\n            functools.partial(\n                lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/qwen2_moe/qwen2_moe_model.py",
    "content": "\"\"\"\nImplementation for QWEN2MOE architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.qwen2.qwen2_model import ACT2FN, QWen2Attention, QWen2Config\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.nn.expert import MixtralExperts\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Qwen2MoeConfig(QWen2Config):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Qwen2Moe model.\"\"\"\n\n    moe_intermediate_size: int = 0\n    shared_expert_intermediate_size: int = 0\n    num_experts_per_tok: int = 0\n    num_experts: int = 0\n    decoder_sparse_step: int = 0\n    norm_topk_prob: bool = False\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass Qwen2MoeMLP(nn.Module):\n    def __init__(self, config: Qwen2MoeConfig, intermediate_size: Optional[int] = None):\n        intermediate_size = intermediate_size or config.intermediate_size\n        if intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return 
self.down_proj(self.act_fn(x1) * x2)\n\n\nclass Qwen2MoeSparseMoeBlock(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"MoE layer for Qwen2MoE model.\"\"\"\n\n    def __init__(self, config: Qwen2MoeConfig):\n        super().__init__()\n        self.num_experts_per_tok = config.num_experts_per_tok\n        self.num_experts = config.num_experts\n        if config.moe_intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MoE intermediate size {config.moe_intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards\n        self.norm_topk_prob = config.norm_topk_prob\n        self.shared_expert = Qwen2MoeMLP(config, config.shared_expert_intermediate_size)\n        self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)\n\n        self.gate = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=config.num_experts,\n            bias=False,\n        )\n        self.moe_gate_up_proj = MixtralExperts(\n            self.num_experts,\n            in_features=config.hidden_size,\n            out_features=2 * self.moe_intermediate_size,\n        )\n        self.moe_down_proj = MixtralExperts(\n            self.num_experts,\n            in_features=self.moe_intermediate_size,\n            out_features=config.hidden_size,\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        def _expert_forward(x: Tensor, indptr: Tensor):\n            x1_x2 = self.moe_gate_up_proj(x, indptr)\n            x1, x2 = op.split(x1_x2, indices_or_sections=2, axis=-1)\n            x = self.moe_down_proj(self.act_fn(x1) * x2, indptr)\n            return x\n\n        experts_per_tok = self.num_experts_per_tok\n        num_experts = self.num_experts\n        batch_size, seq_len, 
hidden_size = x.shape\n        num_tokens = batch_size * seq_len\n        x = x.reshape(num_tokens, hidden_size)\n        gate = self.gate(x)\n        # expert_weights: [num_tokens, experts_per_tok]\n        # expert_indices: [num_tokens, experts_per_tok]\n        expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(\n            gate, experts_per_tok, norm_topk_prob=self.norm_topk_prob\n        )\n        if num_tokens == 1:\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = _expert_forward(x, expert_indices)\n        else:\n            # cumsum: [num_tokens * local_experts]\n            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, num_experts)\n            # indices: [num_tokens * experts_per_tok]\n            reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)\n            # indptr: [num_local_experts + 1]\n            indptr = op_ext.moe_misc.get_indptr(\n                cumsum, num_experts, num_tokens, inclusive=False, out_dtype=\"int32\"\n            )\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = op.take(x, token_indices, axis=0)\n            moe_hidden_states = _expert_forward(moe_hidden_states, indptr)\n            moe_hidden_states = op_ext.moe_misc.scatter_output(moe_hidden_states, reverse_indices)\n        # moe_hidden_states: [num_tokens, experts_per_tok, hidden_size]\n        expert_weights = expert_weights.reshape(num_tokens, experts_per_tok, 1)\n        moe_hidden_states = (\n            moe_hidden_states.reshape(num_tokens, experts_per_tok, hidden_size) * expert_weights\n        )\n        # moe_hidden_states: [num_tokens, hidden_size]\n        moe_hidden_states = op_ext.moe_misc.moe_sum(moe_hidden_states, dim=1)\n\n        shared_expert_hidden_states = self.shared_expert(x)\n        shared_expert_hidden_states = (\n            op.sigmoid(self.shared_expert_gate(x)) * shared_expert_hidden_states\n      
  )\n        final_hidden_states = moe_hidden_states + shared_expert_hidden_states\n        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_size)\n        return final_hidden_states\n\n\nclass Qwen2MoeDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen2MoeConfig):\n        super().__init__()\n        self.self_attn = QWen2Attention(config)\n        assert (\n            config.num_experts > 0 and config.decoder_sparse_step == 1\n        ), \"Currently only supports using MoE in every layer.\"\n        self.mlp = Qwen2MoeSparseMoeBlock(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_attention_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            si = self.mlp.shared_expert.intermediate_size\n            mi = self.mlp.moe_intermediate_size\n            _set(\n                self.self_attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            _set(\n                self.self_attn.c_attn.bias,\n                tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n            )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.shared_expert.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_shared_mlp_up\", segs=[si, si], dim=0),\n            )\n            _set(\n                self.mlp.shared_expert.down_proj.weight,\n                tp.ShardSingleDim(\"_shard_shared_mlp_down\", dim=1),\n         
   )\n            _set(\n                self.mlp.moe_gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_moe_mlp_up\", segs=[mi, mi], dim=1),\n            )\n            _set(\n                self.mlp.moe_down_proj.weight,\n                tp.ShardSingleDim(\"_shard_moe_mlp_down\", dim=2),\n            )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.input_layernorm(hidden_states)\n        out = self.self_attn(out, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Qwen2MoeModel(nn.Module):\n    def __init__(self, config: Qwen2MoeConfig):\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Qwen2MoeDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Qwen2MoeForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Qwen2MoeConfig):\n        self.model = Qwen2MoeModel(config)\n        self.lm_head = 
nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.dtype = config.dtype\n        self.hidden_size = config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.intermediate_size = config.intermediate_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.rms_norm_eps = config.rms_norm_eps\n        self.rope_theta = config.rope_theta\n        self.vocab_size = config.vocab_size\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last position\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = 
self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            
num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, 
self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen3/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/qwen3/qwen3_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Qwen3 parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\nfrom typing import Callable, List, Literal\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping, QuantizeMapping\nfrom mlc_llm.quantization import BlockScaleQuantize, Quantization\n\nfrom .qwen3_model import Qwen3Config, Qwen3LMHeadModel\n\n\ndef huggingface(\n    model_config: Qwen3Config,\n    quantization: Quantization,\n    hf_prefix: Literal[\"\", \"model.\"] = \"model.\",\n) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Qwen3Config\n        The configuration of the Qwen3 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    hf_prefix : Literal[\"\", \"model.\"]\n        Prefix used in HuggingFace weight names. Defaults to \"model.\" for standard\n        Qwen3 models. Use \"\" for Qwen3-Embedding models without prefix.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Qwen3LMHeadModel(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    if isinstance(quantization, BlockScaleQuantize):\n        model = quantization.quantize_model(model, QuantizeMapping({}, {}), \"\")\n        if model_config.weight_block_size is None:\n            raise ValueError(\n                \"The input Qwen3 model is not fp8 block quantized. 
\"\n                \"Thus BlockScaleQuantize is not supported.\"\n            )\n\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    if (\n        not isinstance(quantization, BlockScaleQuantize)\n        and model_config.weight_block_size is not None\n    ):\n        raise ValueError(\n            \"The input Qwen3 model is fp8 block quantized. \"\n            \"Please use BlockScaleQuantize for the model.\"\n        )\n\n    def to_hf(name: str) -> str:\n        if hf_prefix == \"model.\":\n            return name\n        return name[6:] if name.startswith(\"model.\") else name\n\n    def add_weight_and_scale_mapping(\n        weight_mlc_name: str,\n        weight_hf_names: List[str],\n        weight_transform_func: Callable,\n    ):\n        mlc_param = named_parameters[weight_mlc_name]\n        hf_names = [to_hf(name) for name in weight_hf_names]\n        mapping.add_mapping(\n            weight_mlc_name,\n            hf_names,\n            functools.partial(weight_transform_func, dtype=mlc_param.dtype),\n        )\n\n        if isinstance(quantization, BlockScaleQuantize):\n            scale_mlc_name = f\"{weight_mlc_name}_scale_inv\"\n            if scale_mlc_name in named_parameters:\n                scale_hf_names = [f\"{name}_scale_inv\" for name in hf_names]\n                scale_param = named_parameters[scale_mlc_name]\n                mapping.add_mapping(\n                    scale_mlc_name,\n                    scale_hf_names,\n                    functools.partial(weight_transform_func, dtype=scale_param.dtype),\n                )\n\n    for i in range(model_config.num_hidden_layers):\n        # map attention weight\n        attn = f\"model.layers.{i}.self_attn\"\n        add_weight_and_scale_mapping(\n            f\"{attn}.c_attn.weight\",\n            [\n                
f\"{attn}.q_proj.weight\",\n                f\"{attn}.k_proj.weight\",\n                f\"{attn}.v_proj.weight\",\n            ],\n            lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n        )\n        if model_config.attention_bias:\n            mlc_name = f\"{attn}.c_attn.bias\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    to_hf(f\"{attn}.q_proj.bias\"),\n                    to_hf(f\"{attn}.k_proj.bias\"),\n                    to_hf(f\"{attn}.v_proj.bias\"),\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n        # map mlp weight\n        mlp = f\"model.layers.{i}.mlp\"\n        add_weight_and_scale_mapping(\n            f\"{mlp}.gate_up_proj.weight\",\n            [\n                f\"{mlp}.gate_proj.weight\",\n                f\"{mlp}.up_proj.weight\",\n            ],\n            lambda gate, up, dtype: np.concatenate([gate, up], axis=0).astype(dtype),\n        )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [to_hf(mlc_name)],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n\n\ndef huggingface_embedding(model_config: Qwen3Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping for Qwen3-Embedding models (no 'model.' prefix).\"\"\"\n    return huggingface(model_config, quantization, \"\")\n"
  },
  {
    "path": "python/mlc_llm/model/qwen3/qwen3_model.py",
    "content": "\"\"\"\nImplementation for QWEN3 architecture.\n\"\"\"\n\nimport dataclasses\nfrom functools import partial\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Qwen3Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Qwen3 model.\"\"\"\n\n    hidden_act: str\n    hidden_size: int\n    intermediate_size: int\n    attention_bias: bool\n    num_attention_heads: int\n    num_hidden_layers: int\n    num_key_value_heads: int\n    rms_norm_eps: float\n    rope_theta: int\n    vocab_size: int\n    tie_word_embeddings: bool = False\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    head_dim: int = 0\n    dtype: str = \"float32\"\n    max_batch_size: int = 1\n    weight_block_size: Optional[Tuple[int, int]] = None\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if \"quantization_config\" in self.kwargs:\n            quantization_config = self.kwargs.get(\"quantization_config\")\n            if (\n                isinstance(quantization_config, dict)\n                and quantization_config.get(\"activation_scheme\", \"\") == \"dynamic\"\n                and quantization_config.get(\"fmt\", \"\") == \"e4m3\"\n                and quantization_config.get(\"quant_method\", \"\") == \"fp8\"\n                and \"weight_block_size\" in quantization_config\n            ):\n                self.weight_block_size = quantization_config.get(\"weight_block_size\")\n            
    if (\n                    not isinstance(self.weight_block_size, (tuple, list))\n                    or len(self.weight_block_size) != 2\n                ):\n                    raise ValueError(\n                        \"Invalid Qwen3 model quantization config: \"\n                        \"weight_block_size must be a tuple of two integers, \"\n                        f\"got {self.weight_block_size} of type {type(self.weight_block_size)}\"\n                    )\n            else:\n                raise ValueError(\n                    \"Invalid Qwen3 model quantization config: unrecognized quantization config: \"\n                    f\"{quantization_config}\"\n                )\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 2048),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 2048)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 2048),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 2048)\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass Qwen3Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Qwen3Config):\n        self.head_dim = config.head_dim\n        if config.num_key_value_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_key_value_heads} key-value attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_attention_heads = config.num_attention_heads // config.tensor_parallel_shards\n        self.num_key_value_heads = config.num_key_value_heads // config.tensor_parallel_shards\n        self.rope_theta = config.rope_theta\n\n        
self.c_attn = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(2 * self.num_key_value_heads + self.num_attention_heads) * self.head_dim,\n            bias=config.attention_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.num_attention_heads * self.head_dim,\n            config.hidden_size,\n            bias=config.attention_bias,\n        )\n        self.q_norm = nn.RMSNorm(config.head_dim, -1, config.rms_norm_eps, bias=False)\n        self.k_norm = nn.RMSNorm(config.head_dim, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_attention_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.c_attn(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        q, k, v = op.split(qkv, [h_q, h_q + h_kv], axis=2)\n        q = self.q_norm(q)\n        k = self.k_norm(k)\n        qkv = op.concat([q, k, v], dim=2)\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_attention_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.silu,\n    \"swish\": nn.silu,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass Qwen3Embedding(nn.Embedding):\n    \"\"\"The embedding module specialized for Qwen3 so that\n    it can be shared with the final lm_head.\n    \"\"\"\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which transposes the weight and multiplies\n        with the input tensor.\n        \"\"\"\n        weight = nn.op.permute_dims(self.weight)\n        return nn.op.matmul(x, weight, 
out_dtype=\"float32\")\n\n\nclass Qwen3MLP(nn.Module):\n    def __init__(self, config: Qwen3Config):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(self.act_fn(x1) * x2)\n\n\nclass Qwen3DecoderLayer(nn.Module):\n    def __init__(self, config: Qwen3Config):\n        self.self_attn = Qwen3Attention(config)\n        self.mlp = Qwen3MLP(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_attention_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.attention_bias:\n                _set(\n                    self.self_attn.c_attn.bias,\n                    
tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj.weight, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.input_layernorm(hidden_states)\n        out = self.self_attn(out, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Qwen3Model(nn.Module):\n    def __init__(self, config: Qwen3Config):\n        self.embed_tokens = Qwen3Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Qwen3LMHeadModel(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: 
Qwen3Config):\n        self.model = Qwen3Model(config)\n        self.tie_word_embeddings = config.tie_word_embeddings\n        if not config.tie_word_embeddings:\n            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.dtype = config.dtype\n        self.hidden_size = config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.intermediate_size = config.intermediate_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.rms_norm_eps = config.rms_norm_eps\n        self.rope_theta = config.rope_theta\n        self.vocab_size = config.vocab_size\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.weight_block_size = config.weight_block_size\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n       
 op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:-1,:]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        if self.tie_word_embeddings:\n            logits = self.model.embed_tokens.lm_head_forward(hidden_states)\n        else:\n            logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        
self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n           
     },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n\n\nclass Qwen3EmbeddingModel(Qwen3LMHeadModel):\n    \"\"\"Qwen3 model for embedding inference.\n\n    Inherits all functionality from Qwen3LMHeadModel and adds methods that\n    return hidden states instead of logits, for use by AsyncEmbeddingEngine.\n    Only compiled when using the 
\"qwen3-embedding\" model type.\n    \"\"\"\n\n    def prefill_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def decode_to_last_hidden_states(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_prefill_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def batch_decode_to_last_hidden_states(\n        self, input_embeds: Tensor, paged_kv_cache: PagedKVCache\n    ):\n        op_ext.configure()\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        return hidden_states, paged_kv_cache\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode_to_last_hidden_states\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n             
       \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode_to_last_hidden_states\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/qwen3_moe/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/qwen3_moe/qwen3_moe_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Qwen3-MoE parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\nfrom typing import Callable, List\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping, QuantizeMapping\nfrom mlc_llm.quantization import BlockScaleQuantize, Quantization\n\nfrom .qwen3_moe_model import Qwen3MoeConfig, Qwen3MoeForCausalLM\n\n\ndef huggingface(model_config: Qwen3MoeConfig, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Qwen3MoeConfig\n        The configuration of the Qwen3Moe model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Qwen3MoeForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    if isinstance(quantization, BlockScaleQuantize):\n        # Convert the model to a block-scale quantized model before loading parameters\n        model = quantization.quantize_model(model, QuantizeMapping({}, {}), \"\")\n        if model_config.weight_block_size is None:\n            raise ValueError(\n                \"The input Qwen3 model is not fp8 block quantized. 
\"\n                \"Thus BlockScaleQuantize is not supported.\"\n            )\n\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    if (\n        not isinstance(quantization, BlockScaleQuantize)\n        and model_config.weight_block_size is not None\n    ):\n        raise ValueError(\n            \"The input Qwen3 model is fp8 block quantized. \"\n            \"Please use BlockScaleQuantize for the model.\"\n        )\n\n    # Helper function to add both weight and scale mappings\n    def add_weight_and_scale_mapping(\n        weight_mlc_name: str,\n        weight_hf_names: List[str],\n        weight_transform_func: Callable,\n    ):\n        mlc_param = named_parameters[weight_mlc_name]\n        mapping.add_mapping(\n            weight_mlc_name,\n            weight_hf_names,\n            functools.partial(weight_transform_func, dtype=mlc_param.dtype),\n        )\n\n        if isinstance(quantization, BlockScaleQuantize):\n            scale_mlc_name = f\"{weight_mlc_name}_scale_inv\"\n            if scale_mlc_name in named_parameters:\n                scale_hf_names = [f\"{name}_scale_inv\" for name in weight_hf_names]\n                scale_param = named_parameters[scale_mlc_name]\n                mapping.add_mapping(\n                    scale_mlc_name,\n                    scale_hf_names,\n                    functools.partial(weight_transform_func, dtype=scale_param.dtype),\n                )\n\n    for i in range(model_config.num_hidden_layers):\n        # map attention weight\n        attn = f\"model.layers.{i}.self_attn\"\n        add_weight_and_scale_mapping(\n            f\"{attn}.c_attn.weight\",\n            [\n                f\"{attn}.q_proj.weight\",\n                f\"{attn}.k_proj.weight\",\n                f\"{attn}.v_proj.weight\",\n            ],\n            lambda q, 
k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n        )\n        if model_config.attention_bias:\n            mlc_name = f\"{attn}.c_attn.bias\"\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.bias\",\n                    f\"{attn}.k_proj.bias\",\n                    f\"{attn}.v_proj.bias\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n        # map mlp moe gate and up weight\n        mlp = f\"model.layers.{i}.mlp\"\n\n        def combine_expert_gate_up(*hf_params, dtype):\n            stack = []\n            for i in range(0, len(hf_params), 2):\n                stack.append(np.concatenate([hf_params[i], hf_params[i + 1]], axis=0))\n            return np.stack(stack, axis=0).astype(dtype)\n\n        add_weight_and_scale_mapping(\n            f\"{mlp}.moe_gate_up_proj.weight\",\n            functools.reduce(\n                lambda a, b: a + b,\n                [\n                    [\n                        f\"{mlp}.experts.{expert_id}.gate_proj.weight\",\n                        f\"{mlp}.experts.{expert_id}.up_proj.weight\",\n                    ]\n                    for expert_id in range(model_config.num_experts)\n                ],\n            ),\n            combine_expert_gate_up,\n        )\n\n        # map mlp moe down projection weight\n        add_weight_and_scale_mapping(\n            f\"{mlp}.moe_down_proj.weight\",\n            [\n                f\"{mlp}.experts.{expert_id}.down_proj.weight\"\n                for expert_id in range(model_config.num_experts)\n            ],\n            lambda *hf_params, dtype: np.stack(hf_params, axis=0).astype(dtype),\n        )\n\n    for mlc_name, mlc_param in named_parameters.items():\n      
  if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/qwen3_moe/qwen3_moe_model.py",
    "content": "\"\"\"\nImplementation for the Qwen3-MoE architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.model.qwen3.qwen3_model import ACT2FN, Qwen3Attention, Qwen3Config\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.nn.expert import MixtralExperts\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Qwen3MoeConfig(Qwen3Config):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Qwen3Moe model.\"\"\"\n\n    moe_intermediate_size: int = 0\n    num_experts_per_tok: int = 0\n    num_experts: int = 0\n    decoder_sparse_step: int = 0\n    norm_topk_prob: bool = False\n\n\n# pylint: disable=invalid-name,missing-docstring,too-many-locals\n\n\nclass Qwen3MoeMLP(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig, intermediate_size: Optional[int] = None):\n        intermediate_size = intermediate_size or config.intermediate_size\n        if intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MoE MLP intermediate size {intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * self.intermediate_size, bias=False)\n        self.down_proj = nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n        self.act_fn = ACT2FN[config.hidden_act]\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(self.act_fn(x1) * x2)\n\n\nclass 
Qwen3MoeSparseMoeBlock(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"MoE layer for Qwen3MoE model.\"\"\"\n\n    def __init__(self, config: Qwen3MoeConfig):\n        super().__init__()\n        self.num_experts_per_tok = config.num_experts_per_tok\n        self.num_experts = config.num_experts\n        if config.moe_intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MoE intermediate size {config.moe_intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.moe_intermediate_size = config.moe_intermediate_size // config.tensor_parallel_shards\n        self.norm_topk_prob = config.norm_topk_prob\n\n        self.gate = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=config.num_experts,\n            bias=False,\n        )\n        self.moe_gate_up_proj = MixtralExperts(\n            self.num_experts,\n            in_features=config.hidden_size,\n            out_features=2 * self.moe_intermediate_size,\n        )\n        self.moe_down_proj = MixtralExperts(\n            self.num_experts,\n            in_features=self.moe_intermediate_size,\n            out_features=config.hidden_size,\n        )\n        self.act_fn = ACT2FN[config.hidden_act]\n        self.dtype = \"float32\"\n\n    def forward(self, x: Tensor):\n        def _expert_forward(x: Tensor, indptr: Tensor):\n            x1_x2 = self.moe_gate_up_proj(x, indptr)\n            x1, x2 = op.split(x1_x2, indices_or_sections=2, axis=-1)\n            x = self.moe_down_proj(self.act_fn(x1) * x2, indptr)\n            return x\n\n        experts_per_tok = self.num_experts_per_tok\n        num_experts = self.num_experts\n        batch_size, seq_len, hidden_size = x.shape\n        num_tokens = batch_size * seq_len\n        x = x.reshape(num_tokens, hidden_size)\n        gate = self.gate(x)\n        # expert_weights: [num_tokens, 
experts_per_tok]\n        # expert_indices: [num_tokens, experts_per_tok]\n        expert_weights, expert_indices = op_ext.moe_misc.gating_softmax_topk(\n            gate, experts_per_tok, norm_topk_prob=self.norm_topk_prob\n        )\n        use_cutlass = op_ext.get_store().cutlass_group_gemm and self.dtype in [\n            \"float16\",\n            \"bfloat16\",\n        ]\n        if num_tokens == 1:\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = _expert_forward(x, expert_indices)\n        else:\n            # cumsum: [num_tokens * local_experts]\n            cumsum = op_ext.moe_misc.moe_cumsum(expert_indices, num_experts)\n            # indices: [num_tokens * experts_per_tok]\n            reverse_indices, token_indices = op_ext.moe_misc.get_indices(cumsum, expert_indices)\n            # indptr: [num_local_experts + 1]\n            if use_cutlass:\n                # indptr: [num_experts]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum, num_experts, num_tokens, inclusive=True, out_dtype=\"int64\"\n                )\n            else:\n                # indptr: [num_experts + 1]\n                indptr = op_ext.moe_misc.get_indptr(\n                    cumsum, num_experts, num_tokens, inclusive=False, out_dtype=\"int32\"\n                )\n            # x: [num_tokens * experts_per_tok, hidden_size]\n            moe_hidden_states = op.take(x, token_indices, axis=0)\n            moe_hidden_states = _expert_forward(moe_hidden_states, indptr)\n            moe_hidden_states = op_ext.moe_misc.scatter_output(moe_hidden_states, reverse_indices)\n        # moe_hidden_states: [num_tokens, experts_per_tok, hidden_size]\n        expert_weights = expert_weights.reshape(num_tokens, experts_per_tok, 1)\n        moe_hidden_states = (\n            moe_hidden_states.reshape(num_tokens, experts_per_tok, hidden_size) * expert_weights\n        )\n        # moe_hidden_states: [num_tokens, 
hidden_size]\n        moe_hidden_states = op_ext.moe_misc.moe_sum(moe_hidden_states, dim=1)\n\n        final_hidden_states = moe_hidden_states\n        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_size)\n        return final_hidden_states\n\n\nclass Qwen3MoeDecoderLayer(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig):\n        super().__init__()\n        self.self_attn = Qwen3Attention(config)\n        assert (\n            config.num_experts > 0 and config.decoder_sparse_step == 1\n        ), \"Currently only support use moe for every layer.\"\n        self.mlp = Qwen3MoeSparseMoeBlock(config)\n        self.input_layernorm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n        self.post_attention_layernorm = nn.RMSNorm(\n            config.hidden_size, -1, config.rms_norm_eps, bias=False\n        )\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_attention_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            mi = self.mlp.moe_intermediate_size\n            _set(\n                self.self_attn.c_attn.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.attention_bias:\n                _set(\n                    self.self_attn.c_attn.bias,\n                    tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.moe_gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_moe_mlp_up\", segs=[mi, mi], dim=1),\n            )\n            _set(\n                self.mlp.moe_down_proj.weight,\n                
tp.ShardSingleDim(\"_shard_moe_mlp_down\", dim=2),\n            )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.input_layernorm(hidden_states)\n        out = self.self_attn(out, paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.post_attention_layernorm(hidden_states)\n        out = self.mlp(out)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Qwen3MoeModel(nn.Module):\n    def __init__(self, config: Qwen3MoeConfig):\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Qwen3MoeDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.RMSNorm(config.hidden_size, -1, config.rms_norm_eps, bias=False)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Qwen3MoeForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Qwen3MoeConfig):\n        self.model = Qwen3MoeModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.dtype = config.dtype\n        self.hidden_size = config.hidden_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.intermediate_size = config.intermediate_size\n      
  self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.rms_norm_eps = config.rms_norm_eps\n        self.rope_theta = config.rope_theta\n        self.vocab_size = config.vocab_size\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.weight_block_size = config.weight_block_size\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:-1,:]\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: 
PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            
rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n  
          \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/rwkv5/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/rwkv5/rwkv5_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's RWKV5 parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom ...loader import ExternMapping\nfrom ...quantization import Quantization\nfrom .rwkv5_model import RWKV5_ForCausalLM, RWKV5Config\n\n\ndef huggingface(model_config: RWKV5Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : RWKV5Config\n        The configuration of the RWKV5 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = RWKV5_ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params = model.export_tvm(  # pylint: disable=unbalanced-tuple-unpacking\n        spec=model.get_default_spec()\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # store time_decay as the per-step decay factor exp(-exp(w)) used by the wkv5 kernel\n        mlc_name = f\"model.blocks.{i}.attention.time_decay\"\n        hf_name = f\"rwkv.blocks.{i}.attention.time_decay\"\n        mlc_param = named_parameters[mlc_name]\n        if mlc_param.dtype != \"float32\":\n            raise ValueError(f\"RWKV5 time_decay should be float32, got {mlc_param.dtype}\")\n        mapping.add_mapping(\n            mlc_name,\n            [hf_name],\n            functools.partial(\n                lambda x, dtype: np.exp(-np.exp(x.astype(dtype))),\n                dtype=mlc_param.dtype,\n            ),\n        )\n\n        # rescale\n        if model_config.rescale_every > 0:\n            for name in [\"feed_forward.value.weight\", 
\"attention.output.weight\"]:\n                mlc_name = f\"model.blocks.{i}.{name}\"\n                hf_name = f\"rwkv.blocks.{i}.{name}\"\n                mlc_param = named_parameters[mlc_name]\n\n                mapping.add_mapping(\n                    mlc_name,\n                    [hf_name],\n                    functools.partial(\n                        lambda x, dtype, t: x.astype(dtype) / (2**t),\n                        dtype=mlc_param.dtype,\n                        t=i // model_config.rescale_every,\n                    ),\n                )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            hf_name = mlc_name.replace(\"model\", \"rwkv\")\n            mapping.add_mapping(\n                mlc_name,\n                [hf_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/rwkv5/rwkv5_model.py",
    "content": "\"\"\"Implementation for RWKV5 architecture.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Object, Tensor, op\nfrom tvm.script import tir as T\n\nfrom mlc_llm.nn.rnn_state import RNNState\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.config import ConfigBase\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass StateID:\n    \"\"\"State ID for RWKV5.\"\"\"\n\n    ATT_X = 0\n    ATT_KV = 1\n    FFN_X = 2\n\n\n@dataclasses.dataclass\nclass RWKV5Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the RWKV5 model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_hidden_layers: int\n    vocab_size: int\n    model_version: str\n    tensor_parallel_shards: int = 1\n    rescale_every: int = 0\n    head_size: int = 64\n    layer_norm_epsilon: float = 1e-5\n    context_window_size: int = -1  # RWKV does not have context window limitation.\n    prefill_chunk_size: int = 4096\n    num_heads: int = 0\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.model_version != \"5_2\":\n            raise ValueError(f\"Only support RWKV v5_2, got {self.model_version}.\")\n        self.intermediate_size = self.intermediate_size or int((self.hidden_size * 3.5)) // 32 * 32\n        self.num_heads = (\n            self.hidden_size // self.head_size if self.num_heads == 0 else self.num_heads\n        )\n        if self.num_heads * self.head_size != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size ({self.hidden_size}) must be divisible \"\n                f\"by head_size ({self.head_size})\"\n            )\n        if self.tensor_parallel_shards != 1:\n            raise ValueError(\"Only support single device at this moment.\")\n\n\n# 
pylint: disable=invalid-name,missing-docstring\n# pylint: disable=too-many-arguments, too-many-locals, redefined-argument-from-local\ndef create_wkv5_func(\n    num_heads: int,\n    head_size: int,\n    dtype: str,\n    out_dtype: str,\n    state_dtype: str,\n):\n    @T.prim_func\n    def wkv_func(\n        r: T.handle,\n        k: T.handle,\n        v: T.handle,\n        time_decay: T.handle,\n        time_faaaa: T.handle,\n        state: T.handle,\n        out: T.handle,\n        out_state: T.handle,\n    ):\n        T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n        batch_size, seq_len = T.int64(), T.int64()\n        # Inputs\n        r_buf = T.match_buffer(r, (batch_size, seq_len, num_heads, head_size), dtype=dtype)\n        k_buf = T.match_buffer(k, (batch_size, seq_len, num_heads, head_size), dtype=dtype)\n        v_buf = T.match_buffer(v, (batch_size, seq_len, num_heads, head_size), dtype=dtype)\n        time_decay_buf = T.match_buffer(time_decay, (num_heads, head_size), dtype=\"float32\")\n        time_faaaa_buf = T.match_buffer(time_faaaa, (num_heads, head_size), dtype=\"float32\")\n        state_buf = T.match_buffer(\n            state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype\n        )\n        # Outputs\n        out_buf = T.match_buffer(out, (batch_size, seq_len, num_heads, head_size), dtype=out_dtype)\n        out_state_buf = T.match_buffer(\n            out_state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype\n        )\n        for b in T.thread_binding(batch_size, thread=\"blockIdx.y\"):\n            for h in T.thread_binding(num_heads, thread=\"blockIdx.x\"):\n                for i in T.thread_binding(head_size, thread=\"threadIdx.x\"):\n                    for j in range(head_size):\n                        with T.sblock(\"init_state\"):\n                            vb, vh, vi, vj = T.axis.remap(\"SSSS\", [b, h, i, j])\n                            out_state_buf[vb, 
vh, vi, vj] = state_buf[vb, vh, vi, vj]\n\n                    for t in range(seq_len):\n                        with T.sblock(\"comput\"):\n                            vb = T.axis.spatial(batch_size, b)\n                            vt = T.axis.opaque(seq_len, t)\n                            vh = T.axis.spatial(num_heads, h)\n                            vi = T.axis.spatial(head_size, i)\n                            out_buf[vb, vt, vh, vi] = 0\n\n                            for k in range(head_size):\n                                x = k_buf[vb, vt, vh, k] * v_buf[vb, vt, vh, vi]\n                                out_buf[vb, vt, vh, vi] += T.cast(\n                                    r_buf[vb, vt, vh, k], out_dtype\n                                ) * T.cast(\n                                    time_faaaa_buf[vh, k] * x + out_state_buf[vb, vh, vi, k],\n                                    out_dtype,\n                                )\n                                out_state_buf[vb, vh, vi, k] = (\n                                    out_state_buf[vb, vh, vi, k] * time_decay_buf[vh, k] + x\n                                )\n\n    return wkv_func\n\n\n# pylint: enable=too-many-arguments, too-many-locals\n\n\ndef token_shift(state: Tensor, x: Tensor):\n    def _te_token_shift(state: te.Tensor, x: te.Tensor):\n        return te.compute(\n            x.shape,\n            lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]),\n        )\n\n    return op.tensor_expr_op(_te_token_shift, \"token_shift\", [state, x])\n\n\ndef last_token(x: Tensor):\n    # x.shape = (batch, seq_len, hidden_size)\n    batch, seq_len, hidden_size = x.shape\n\n    def _te_last_token(x: te.Tensor):\n        return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j])\n\n    return x if seq_len == 1 else op.tensor_expr_op(_te_last_token, \"last_token\", [x])\n\n\nclass RWKV5_FNN(nn.Module):\n    def __init__(self, config: RWKV5Config, layer_id: int):\n     
   super().__init__()\n        self.time_mix_key = nn.Parameter((1, 1, config.hidden_size))\n        self.time_mix_receptance = nn.Parameter((1, 1, config.hidden_size))\n        self.key = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)\n        self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.value = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)\n        self.layer_id = layer_id\n\n    def forward(self, x: Tensor, state: RNNState):\n        batch, _, hidden_size = x.shape\n        state_x = state.get(self.layer_id, StateID.FFN_X, (batch, hidden_size), x.dtype)\n        state_x = token_shift(state_x, x)\n        xk = x * self.time_mix_key + state_x * (1.0 - self.time_mix_key)\n        xr = x * self.time_mix_receptance + state_x * (1.0 - self.time_mix_receptance)\n        last_x = last_token(x).reshape(batch, hidden_size)\n        state = state.set(self.layer_id, StateID.FFN_X, last_x)\n        r = op.sigmoid(self.receptance(xr))\n        xv = op.square(op.relu(self.key(xk)))\n        return r * self.value(xv), state\n\n\nclass RWKV5_Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Attention layer for RWKV.\"\"\"\n\n    def __init__(self, config: RWKV5Config, layer_id: int):\n        super().__init__()\n        self.time_decay = nn.Parameter((config.num_heads, config.head_size))\n        self.time_faaaa = nn.Parameter((config.num_heads, config.head_size))\n\n        self.time_mix_gate = nn.Parameter((1, 1, config.hidden_size))\n        self.time_mix_key = nn.Parameter((1, 1, config.hidden_size))\n        self.time_mix_value = nn.Parameter((1, 1, config.hidden_size))\n        self.time_mix_receptance = nn.Parameter((1, 1, config.hidden_size))\n\n        self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.receptance = 
nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.gate = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.ln_x = nn.GroupNorm(\n            config.num_heads,\n            config.hidden_size,\n        )\n        self.hidden_size = config.hidden_size\n        self.head_size = config.head_size\n        self.num_heads = config.num_heads\n        self.layer_id = layer_id\n        self.dtype = \"float32\"\n\n    def forward(self, x: Tensor, state: RNNState):  # pylint: disable=too-many-locals\n        batch, seq_len, hidden_size = x.shape\n        assert hidden_size == self.hidden_size\n        B, T, H, N = (  # pylint: disable=redefined-outer-name\n            batch,\n            seq_len,\n            self.head_size,\n            self.num_heads,\n        )\n        x_state = state.get(self.layer_id, StateID.ATT_X, (batch, self.hidden_size), x.dtype)\n        x_state = token_shift(x_state, x)\n        kv_state = state.get(\n            self.layer_id,\n            StateID.ATT_KV,\n            (batch, self.num_heads, self.head_size, self.head_size),\n            \"float32\",  # Always use float32 for state KV.\n        )\n\n        xk = x * self.time_mix_key + x_state * (1.0 - self.time_mix_key)\n        xv = x * self.time_mix_value + x_state * (1.0 - self.time_mix_value)\n        xr = x * self.time_mix_receptance + x_state * (1.0 - self.time_mix_receptance)\n        xg = x * self.time_mix_gate + x_state * (1.0 - self.time_mix_gate)\n\n        r = op.reshape(self.receptance(xr), (B, T, N, H))\n        k = op.reshape(self.key(xk), (B, T, N, H))\n        v = op.reshape(self.value(xv), (B, T, N, H))\n        g = op.silu(self.gate(xg))\n\n        out, kv_state = op.tensor_ir_op(\n            create_wkv5_func(\n                self.num_heads,\n                self.head_size,\n                dtype=self.dtype,\n                
out_dtype=\"float32\",\n                state_dtype=\"float32\",\n            ),\n            \"wkv5\",\n            [r, k, v, self.time_decay, self.time_faaaa, kv_state],\n            [\n                Tensor.placeholder([B, T, N, H], \"float32\"),\n                Tensor.placeholder([B, N, H, H], \"float32\"),\n            ],\n        )\n\n        last_x = last_token(x).reshape(batch, hidden_size)\n        state = state.set(self.layer_id, StateID.ATT_X, last_x)\n        state = state.set(self.layer_id, StateID.ATT_KV, kv_state)\n        out = op.astype(self.ln_x(op.reshape(out, x.shape), channel_axis=-1, axes=[]), self.dtype)\n        return self.output(out * g), state\n\n    def to(self, dtype: Optional[str] = None):\n        # RWKV keeps some parameters in float32, so dtype conversion is done per parameter.\n        if dtype is not None:\n            self.dtype = dtype\n\n        self.time_mix_gate.to(dtype)\n        self.time_mix_key.to(dtype)\n        self.time_mix_value.to(dtype)\n        self.time_mix_receptance.to(dtype)\n        self.key.to(dtype)\n        self.value.to(dtype)\n        self.receptance.to(dtype)\n        self.gate.to(dtype)\n        self.output.to(dtype)\n\n        # These parameters must stay in float32 regardless of the model dtype.\n        self.time_decay.to(\"float32\")\n        self.time_faaaa.to(\"float32\")\n        self.ln_x.to(\"float32\")\n\n\nclass RWKV5_Layer(nn.Module):\n    def __init__(self, config: RWKV5Config, layer_id: int):\n        super().__init__()\n        if layer_id == 0:\n            self.pre_ln = nn.LayerNorm(\n                config.hidden_size,\n                eps=config.layer_norm_epsilon,\n            )\n        self.ln1 = nn.LayerNorm(\n            config.hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n        self.ln2 = nn.LayerNorm(\n            config.hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n        self.attention = RWKV5_Attention(config, layer_id)\n        self.feed_forward = 
RWKV5_FNN(config, layer_id)\n        self.layer_id = layer_id\n        self.rescale_every = config.rescale_every\n\n    def forward(self, x: Tensor, state: RNNState) -> Tensor:\n        if self.layer_id == 0:\n            x = self.pre_ln(x)\n        att_x, state = self.attention(self.ln1(x), state)\n        x += att_x\n        ffn_x, state = self.feed_forward(self.ln2(x), state)\n        x += ffn_x\n        if self.rescale_every > 0 and (self.layer_id + 1) % self.rescale_every == 0:\n            x = x / 2.0\n        return x, state\n\n\nclass RWKV5_Model(nn.Module):\n    \"\"\"RWKV5 backbone: token embeddings, a stack of RWKV5 blocks and a final LayerNorm.\"\"\"\n\n    def __init__(self, config: RWKV5Config):\n        super().__init__()\n        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.blocks = nn.ModuleList(\n            [RWKV5_Layer(config, i) for i in range(config.num_hidden_layers)]\n        )\n        self.ln_out = nn.LayerNorm(\n            config.hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n\n    def forward(self, input_embed: Tensor, state: RNNState):\n        \"\"\"Forward pass of the model, passing through all decoder layers.\"\"\"\n        hidden_states = input_embed\n        for block in self.blocks:\n            hidden_states, state = block(hidden_states, state)\n        return self.ln_out(hidden_states), state\n\n\nclass RWKV5_ForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"RWKV5 causal language model: the RWKV5 backbone followed by a linear LM head.\"\"\"\n\n    def __init__(self, config: RWKV5Config):\n        self.model = RWKV5_Model(config)\n        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_heads\n        self.head_size = config.head_size\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def embed(self, input_ids: Tensor):\n        return self.model.embeddings(input_ids)\n\n    def forward(\n        self,\n        input_embed: Tensor,\n        state: RNNState,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        \"\"\"Forward pass.\"\"\"\n        hidden_states, state = self.model(input_embed, state)\n        hidden_states = last_token(hidden_states)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, state\n\n    def prefill(self, input_embed: Tensor, state: RNNState):\n        \"\"\"Prefilling the prompt.\"\"\"\n        return self.forward(input_embed, state)\n\n    def decode(self, input_embed: Tensor, state: RNNState):\n        \"\"\"Decoding step.\"\"\"\n        return self.forward(input_embed, state)\n\n    def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState):\n        \"\"\"Prefilling the prompt.\"\"\"\n        return self.forward(input_embeds, state, logit_positions=logit_positions)\n\n    def batch_decode(self, input_embeds: Tensor, state: RNNState):\n        \"\"\"Decoding step.\"\"\"\n        return self.forward(input_embeds, state)\n\n    def batch_verify(self, input_embeds: Tensor, state: RNNState):\n        \"\"\"Verify step.\"\"\"\n        return self.forward(input_embeds, state)\n\n    def create_rnn_state(\n        self,\n        max_batch_size: tir.Var,\n        max_history: tir.Var,\n    ) -> Object:\n        \"\"\"Create RNN state.\"\"\"\n        init_values = [\n            op.zeros((self.hidden_size,), dtype=self.dtype),  # ATT_X\n            op.zeros((self.num_heads, self.head_size, self.head_size), dtype=\"float32\"),  # ATT_KV\n            
op.zeros((self.hidden_size,), dtype=self.dtype),  # FFN_X\n        ]\n        return RNNState.create(\n            max_batch_size=max_batch_size,\n            num_hidden_layers=self.num_hidden_layers,\n            max_history=max_history,\n            init_values=init_values,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": 
\"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_rnn_state\": {\n                \"max_batch_size\": int,\n                \"max_history\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/rwkv6/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/rwkv6/rwkv6_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's RWKV6 parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nfrom ...loader import ExternMapping\nfrom ...quantization import Quantization\nfrom .rwkv6_model import RWKV6_ForCausalLM, RWKV6Config\n\n\ndef huggingface(model_config: RWKV6Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : RWKV6Config\n        The configuration of the RWKV6 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = RWKV6_ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params = model.export_tvm(  # pylint: disable=unbalanced-tuple-unpacking\n        spec=model.get_default_spec()\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    for i in range(model_config.num_hidden_layers):\n        # rescale\n        if model_config.rescale_every > 0:\n            for name in [\"feed_forward.value.weight\", \"attention.output.weight\"]:\n                mlc_name = f\"model.blocks.{i}.{name}\"\n                hf_name = f\"rwkv.blocks.{i}.{name}\"\n                mlc_param = named_parameters[mlc_name]\n\n                mapping.add_mapping(\n                    mlc_name,\n                    [hf_name],\n                    functools.partial(\n                        lambda x, dtype, t: x.astype(dtype) / (2**t),\n                        dtype=mlc_param.dtype,\n                        t=i // model_config.rescale_every,\n                    ),\n                )\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            hf_name = mlc_name.replace(\"model\", \"rwkv\")\n            mapping.add_mapping(\n                mlc_name,\n                [hf_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/rwkv6/rwkv6_model.py",
    "content": "\"\"\"Implementation for RWKV6 architecture.\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional, Tuple\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Object, Tensor, op\nfrom tvm.script import tir as T\n\nfrom mlc_llm.nn.rnn_state import RNNState\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.config import ConfigBase\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass StateID:\n    \"\"\"State ID for RWKV6.\"\"\"\n\n    ATT_X = 0\n    ATT_KV = 1\n    FFN_X = 2\n\n\n@dataclasses.dataclass\nclass RWKV6Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the RWKV6 model.\"\"\"\n\n    hidden_size: int\n    intermediate_size: int\n    num_hidden_layers: int\n    vocab_size: int\n    model_version: str = \"6_0\"\n    tensor_parallel_shards: int = 1\n    rescale_every: int = 0\n    head_size: int = 64\n    layer_norm_epsilon: float = 1e-5\n    context_window_size: int = -1  # RWKV does not have context window limitation.\n    prefill_chunk_size: int = 4096\n    num_heads: int = 0\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.model_version != \"6_0\":\n            raise ValueError(f\"Only support RWKV v6_0, got {self.model_version}.\")\n        self.intermediate_size = self.intermediate_size or int((self.hidden_size * 3.5)) // 32 * 32\n        self.num_heads = (\n            self.hidden_size // self.head_size if self.num_heads == 0 else self.num_heads\n        )\n        if self.num_heads * self.head_size != self.hidden_size:\n            raise ValueError(\n                f\"hidden_size ({self.hidden_size}) must be divisible \"\n                f\"by head_size ({self.head_size})\"\n            )\n        if self.tensor_parallel_shards != 1:\n            raise ValueError(\"Only support single device at this 
moment.\")\n\n\n# pylint: disable=invalid-name, missing-docstring\n# pylint: disable=too-many-arguments, too-many-locals, redefined-argument-from-local\ndef create_wkv6_func(\n    num_heads: int,\n    head_size: int,\n    dtype: str,\n    out_dtype: str,\n    state_dtype: str,\n):\n    @T.prim_func\n    def wkv_func(\n        r: T.handle,\n        k: T.handle,\n        v: T.handle,\n        time_faaaa: T.handle,\n        w: T.handle,\n        state: T.handle,\n        out: T.handle,\n        out_state: T.handle,\n    ):\n        T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n        batch_size, seq_len = T.int64(), T.int64()\n        # Inputs\n        r_buf = T.match_buffer(r, (batch_size, seq_len, num_heads, head_size), dtype=dtype)\n        k_buf = T.match_buffer(k, (batch_size, seq_len, num_heads, head_size), dtype=dtype)\n        v_buf = T.match_buffer(v, (batch_size, seq_len, num_heads, head_size), dtype=dtype)\n        time_faaaa_buf = T.match_buffer(time_faaaa, (num_heads, head_size), dtype=\"float32\")\n        w_buf = T.match_buffer(w, (batch_size, seq_len, num_heads, head_size), dtype=\"float32\")\n        state_buf = T.match_buffer(\n            state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype\n        )\n        # Outputs\n        out_buf = T.match_buffer(out, (batch_size, seq_len, num_heads, head_size), dtype=out_dtype)\n        out_state_buf = T.match_buffer(\n            out_state, (batch_size, num_heads, head_size, head_size), dtype=state_dtype\n        )\n        for b in T.thread_binding(batch_size, thread=\"blockIdx.y\"):\n            for h in T.thread_binding(num_heads, thread=\"blockIdx.x\"):\n                for i in T.thread_binding(head_size, thread=\"threadIdx.x\"):\n                    for j in range(head_size):\n                        with T.sblock(\"init_state\"):\n                            vb, vh, vi, vj = T.axis.remap(\"SSSS\", [b, h, i, j])\n                            
out_state_buf[vb, vh, vi, vj] = state_buf[vb, vh, vi, vj]\n\n                    for t in range(seq_len):\n                        with T.sblock(\"comput\"):\n                            vb = T.axis.spatial(batch_size, b)\n                            vt = T.axis.opaque(seq_len, t)\n                            vh = T.axis.spatial(num_heads, h)\n                            vi = T.axis.spatial(head_size, i)\n                            out_buf[vb, vt, vh, vi] = 0\n\n                            for k in range(head_size):\n                                at = k_buf[vb, vt, vh, k] * v_buf[vb, vt, vh, vi]\n                                out_buf[vb, vt, vh, vi] += T.cast(\n                                    r_buf[vb, vt, vh, k], out_dtype\n                                ) * T.cast(\n                                    time_faaaa_buf[vh, k] * at + out_state_buf[vb, vh, vi, k],\n                                    out_dtype,\n                                )\n                                out_state_buf[vb, vh, vi, k] = (\n                                    at + w_buf[vb, vt, vh, k] * out_state_buf[vb, vh, vi, k]\n                                )\n\n    return wkv_func\n\n\ndef token_shift(state: Tensor, x: Tensor):\n    def _te_token_shift(state: te.Tensor, x: te.Tensor):\n        return te.compute(\n            x.shape,\n            lambda b, i, j: tir.if_then_else(i == 0, state[b, j], x[b, i - 1, j]),\n        )\n\n    return op.tensor_expr_op(_te_token_shift, \"token_shift\", [state, x])\n\n\ndef last_token(x: Tensor):\n    batch, seq_len, hidden_size = x.shape\n\n    def _te_last_token(x: te.Tensor):\n        return te.compute((batch, 1, hidden_size), lambda b, _, j: x[b, x.shape[1] - 1, j])\n\n    return x if seq_len == 1 else op.tensor_expr_op(_te_last_token, \"last_token\", [x])\n\n\ndef unbind_to_five(x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n    assert x.shape[0] == 5\n\n    def _te_get_ith(x: te.Tensor, i: int):\n        return 
te.compute((1, *x.shape[1:]), lambda _, j, k, l: x[i, j, k, l])\n\n    return (\n        op.reshape(op.tensor_expr_op(_te_get_ith, \"unbind_to_five\", [x, 0]), x.shape[1:]),\n        op.reshape(op.tensor_expr_op(_te_get_ith, \"unbind_to_five\", [x, 1]), x.shape[1:]),\n        op.reshape(op.tensor_expr_op(_te_get_ith, \"unbind_to_five\", [x, 2]), x.shape[1:]),\n        op.reshape(op.tensor_expr_op(_te_get_ith, \"unbind_to_five\", [x, 3]), x.shape[1:]),\n        op.reshape(op.tensor_expr_op(_te_get_ith, \"unbind_to_five\", [x, 4]), x.shape[1:]),\n    )\n\n\nclass RWKV6_FNN(nn.Module):\n    def __init__(self, config: RWKV6Config, layer_id: int):\n        super().__init__()\n        self.time_maa_k = nn.Parameter((1, 1, config.hidden_size))\n        self.time_maa_r = nn.Parameter((1, 1, config.hidden_size))\n        self.key = nn.Linear(config.hidden_size, config.hidden_size // 2 * 7, bias=False)\n        self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.value = nn.Linear(config.hidden_size // 2 * 7, config.hidden_size, bias=False)\n        self.layer_id = layer_id\n\n    def forward(self, x: Tensor, state: RNNState):\n        batch, _, hidden_size = x.shape\n        state_x = state.get(self.layer_id, StateID.FFN_X, (batch, hidden_size), x.dtype)\n        state_x = token_shift(state_x, x)\n\n        state_x = state_x - x\n        xk = x + state_x * self.time_maa_k\n        xr = x + state_x * self.time_maa_r\n\n        last_x = last_token(x).reshape(batch, hidden_size)\n        state = state.set(self.layer_id, StateID.FFN_X, last_x)\n\n        r = op.sigmoid(self.receptance(xr))\n        xv = op.square(op.relu(self.key(xk)))\n        return r * self.value(xv), state\n\n\nclass RWKV6_Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Attention layer for RWKV.\"\"\"\n\n    def __init__(self, config: RWKV6Config, layer_id: int):\n        super().__init__()\n        self.time_maa_x = nn.Parameter((1, 
1, config.hidden_size))\n        self.time_maa_w = nn.Parameter((1, 1, config.hidden_size))\n        self.time_maa_k = nn.Parameter((1, 1, config.hidden_size))\n        self.time_maa_v = nn.Parameter((1, 1, config.hidden_size))\n        self.time_maa_r = nn.Parameter((1, 1, config.hidden_size))\n        self.time_maa_g = nn.Parameter((1, 1, config.hidden_size))\n\n        # RWKV v6 7B/14B\n        if config.hidden_size == 4096:\n            time_mix_extra_dim = 64\n            time_decay_extra_dim = 128\n        else:\n            time_mix_extra_dim = 32\n            time_decay_extra_dim = 64\n\n        self.time_maa_w1 = nn.Parameter((config.hidden_size, 5 * time_mix_extra_dim))\n        self.time_maa_w2 = nn.Parameter((5, time_mix_extra_dim, config.hidden_size))\n        self.time_decay_w1 = nn.Parameter((config.hidden_size, time_decay_extra_dim))\n        self.time_decay_w2 = nn.Parameter((time_decay_extra_dim, config.hidden_size))\n        self.time_decay = nn.Parameter((1, 1, config.hidden_size))\n        self.time_faaaa = nn.Parameter((config.num_heads, config.head_size))\n\n        self.key = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.value = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.receptance = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.gate = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)\n        self.ln_x = nn.GroupNorm(config.num_heads, config.hidden_size)\n        self.hidden_size = config.hidden_size\n        self.head_size = config.head_size\n        self.num_heads = config.num_heads\n        self.layer_id = layer_id\n        self.dtype = \"float32\"\n\n    def forward(self, x: Tensor, state: RNNState):  # pylint: disable=too-many-locals\n        batch, seq_len, hidden_size = x.shape\n        assert hidden_size == self.hidden_size\n        B, T, H, N = 
(  # pylint: disable=redefined-outer-name\n            batch,\n            seq_len,\n            self.head_size,\n            self.num_heads,\n        )\n        state_x = state.get(self.layer_id, StateID.ATT_X, (batch, self.hidden_size), x.dtype)\n        state_x = token_shift(state_x, x)\n        state_x = state_x - x\n        xxx = x + state_x * self.time_maa_x\n        xxx = op.permute(\n            op.reshape(op.tanh(op.matmul(xxx, self.time_maa_w1)), (B, T, 5, -1)),\n            [0, 2, 1, 3],\n        )\n        xxx = op.permute(\n            op.matmul(xxx, self.time_maa_w2), axes=[1, 0, 2, 3]\n        )  # it's a batch matrix-matrix multiplication\n        mw, mk, mv, mr, mg = unbind_to_five(xxx)\n\n        kv_state = state.get(\n            self.layer_id,\n            StateID.ATT_KV,\n            (batch, self.num_heads, self.head_size, self.head_size),\n            \"float32\",\n        )\n\n        xw = x + state_x * (self.time_maa_w + mw)\n        xk = x + state_x * (self.time_maa_k + mk)\n        xv = x + state_x * (self.time_maa_v + mv)\n        xr = x + state_x * (self.time_maa_r + mr)\n        xg = x + state_x * (self.time_maa_g + mg)\n\n        r = op.reshape(self.receptance(xr), (B, T, N, H))\n        k = op.reshape(self.key(xk), (B, T, N, H))\n        v = op.reshape(self.value(xv), (B, T, N, H))\n        g = op.silu(self.gate(xg))\n\n        w = op.reshape(self.time_decay, (1, N, H)).astype(\"float32\") + op.reshape(\n            op.matmul(op.tanh(op.matmul(xw, self.time_decay_w1)), self.time_decay_w2),\n            (B, T, N, H),\n        ).astype(\"float32\")\n        w = op.exp(op.negative(op.exp(w)))\n        # w = op.reshape(w, [B, T, N, H])\n\n        out, kv_state = op.tensor_ir_op(\n            create_wkv6_func(\n                num_heads=self.num_heads,\n                head_size=self.head_size,\n                dtype=self.dtype,\n                out_dtype=\"float32\",\n                state_dtype=\"float32\",\n            ),\n            
\"wkv6\",\n            [r, k, v, self.time_faaaa, w, kv_state],\n            [\n                Tensor.placeholder([B, T, N, H], \"float32\"),\n                Tensor.placeholder([B, N, H, H], \"float32\"),\n            ],\n        )\n\n        last_x = last_token(x).reshape(batch, hidden_size)\n        state = state.set(self.layer_id, StateID.ATT_X, last_x)\n        state = state.set(self.layer_id, StateID.ATT_KV, kv_state)\n        out = op.astype(self.ln_x(op.reshape(out, x.shape), channel_axis=-1, axes=[]), self.dtype)\n        return self.output(out * g), state\n\n    def to(self, dtype: Optional[str] = None):\n        # RWKV uses special dtype, so we need to convert it.\n        if dtype is not None:\n            self.dtype = dtype\n\n        self.time_maa_x.to(dtype)\n        self.time_maa_w.to(dtype)\n        self.time_maa_k.to(dtype)\n        self.time_maa_v.to(dtype)\n        self.time_maa_r.to(dtype)\n        self.time_maa_g.to(dtype)\n        self.time_maa_w1.to(dtype)\n        self.time_maa_w2.to(dtype)\n        self.time_decay_w1.to(dtype)\n        self.time_decay_w2.to(dtype)\n        self.key.to(dtype)\n        self.value.to(dtype)\n        self.receptance.to(dtype)\n        self.gate.to(dtype)\n        self.output.to(dtype)\n\n        # These parameters are necessary to be converted to float32.\n        self.time_decay.to(\"float32\")\n        self.time_faaaa.to(\"float32\")\n        self.ln_x.to(\"float32\")\n\n\nclass RWKV6_Layer(nn.Module):\n    def __init__(self, config: RWKV6Config, layer_id: int):\n        super().__init__()\n        if layer_id == 0:\n            self.pre_ln = nn.LayerNorm(\n                config.hidden_size,\n                eps=config.layer_norm_epsilon,\n            )\n        self.ln1 = nn.LayerNorm(\n            config.hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n        self.ln2 = nn.LayerNorm(\n            config.hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n        
self.attention = RWKV6_Attention(config, layer_id)\n        self.feed_forward = RWKV6_FNN(config, layer_id)\n        self.layer_id = layer_id\n        self.rescale_every = config.rescale_every\n\n    def forward(self, x: Tensor, state: RNNState) -> Tuple[Tensor, RNNState]:\n        if self.layer_id == 0:\n            x = self.pre_ln(x)\n        att_x, state = self.attention(self.ln1(x), state)\n        x += att_x\n        ffn_x, state = self.feed_forward(self.ln2(x), state)\n        x += ffn_x\n        if self.rescale_every > 0 and (self.layer_id + 1) % self.rescale_every == 0:\n            x = x / 2.0\n        return x, state\n\n\nclass RWKV6_Model(nn.Module):\n    \"\"\"RWKV-6 backbone: embeddings, a stack of RWKV6_Layer blocks and a final LayerNorm.\"\"\"\n\n    def __init__(self, config: RWKV6Config):\n        super().__init__()\n        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.blocks = nn.ModuleList(\n            [RWKV6_Layer(config, i) for i in range(config.num_hidden_layers)]\n        )\n        self.ln_out = nn.LayerNorm(\n            config.hidden_size,\n            eps=config.layer_norm_epsilon,\n        )\n\n    def forward(self, input_embed: Tensor, state: RNNState):\n        \"\"\"Forward pass through all RWKV blocks, followed by the output LayerNorm.\"\"\"\n        hidden_states = input_embed\n        for block in self.blocks:\n            hidden_states, state = block(hidden_states, state)\n        return self.ln_out(hidden_states), state\n\n\nclass RWKV6_ForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"RWKV-6 causal LM: RWKV6_Model plus a linear head; state is kept in an RNNState.\"\"\"\n\n    def __init__(self, config: RWKV6Config):\n        self.model = RWKV6_Model(config)\n        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_heads = config.num_heads\n        self.head_size = config.head_size\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def embed(self, input_ids: Tensor):\n        return self.model.embeddings(input_ids)\n\n    def forward(\n        self,\n        input_embed: Tensor,\n        state: RNNState,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        \"\"\"Forward pass.\"\"\"\n        hidden_states, state = self.model(input_embed, state)\n        hidden_states = last_token(hidden_states)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, state\n\n    def prefill(self, input_embed: Tensor, state: RNNState):\n        \"\"\"Prefilling the prompt.\"\"\"\n        return self.forward(input_embed, state)\n\n    def decode(self, input_embed: Tensor, state: RNNState):\n        \"\"\"Decoding step.\"\"\"\n        return self.forward(input_embed, state)\n\n    def batch_prefill(self, input_embeds: Tensor, logit_positions: Tensor, state: RNNState):\n        \"\"\"Prefilling the prompt.\"\"\"\n        return self.forward(input_embeds, state, logit_positions=logit_positions)\n\n    def batch_decode(self, input_embeds: Tensor, state: RNNState):\n        \"\"\"Decoding step.\"\"\"\n        return self.forward(input_embeds, state)\n\n    def batch_verify(self, input_embeds: Tensor, state: RNNState):\n        \"\"\"Verify step.\"\"\"\n        return self.forward(input_embeds, state)\n\n    def create_rnn_state(\n        self,\n        max_batch_size: tir.Var,\n        max_history: tir.Var,\n    ) -> Object:\n        \"\"\"Create RNN state.\"\"\"\n        init_values = [\n            op.zeros((self.hidden_size,), dtype=self.dtype),  # ATT_X\n            
op.zeros((self.num_heads, self.head_size, self.head_size), dtype=\"float32\"),  # ATT_KV\n            op.zeros((self.hidden_size,), dtype=self.dtype),  # FFN_X\n        ]\n        return RNNState.create(\n            max_batch_size=max_batch_size,\n            num_hidden_layers=self.num_hidden_layers,\n            max_history=max_history,\n            init_values=init_values,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                
\"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"state\": nn.spec.Object(object_type=RNNState),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_rnn_state\": {\n                \"max_batch_size\": int,\n                \"max_history\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/stable_lm/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/stable_lm/stablelm_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's StableLM parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nfrom mlc_llm.loader.standard_loader import make_standard_hf_loader\n\nfrom .stablelm_model import StableLmForCausalLM\n\nhuggingface = make_standard_hf_loader(\n    model_cls=StableLmForCausalLM,\n    add_qkv_bias=True,\n    qkv_bias_optional=True,\n)\n"
  },
  {
    "path": "python/mlc_llm/model/stable_lm/stablelm_model.py",
    "content": "\"\"\"\nImplementation for StableLM architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass StableLmConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the StableLM model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    layer_norm_eps: float\n    partial_rotary_factor: float\n    rope_theta: int\n    intermediate_size: int\n    use_qkv_bias: bool = False  # Default to False for Stable-LM 3B model\n    head_dim: int = 0\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass StableLmAttention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: StableLmConfig):\n        self.hidden_size = config.hidden_size\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        if config.num_key_value_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split {config.num_key_value_heads} key-value attention heads \"\n         
       f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.num_heads = config.num_attention_heads // self.tensor_parallel_shards\n        self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.rotary_ndims = int(config.partial_rotary_factor * self.head_dim)\n\n        self.qkv_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=config.use_qkv_bias,\n        )\n        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.qkv_proj(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nclass StableLmMLP(nn.Module):\n    def __init__(self, config: StableLmConfig):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        self.gate_up_proj = nn.Linear(\n            in_features=config.hidden_size,\n            out_features=2 * self.intermediate_size,\n            bias=False,\n        )\n        self.down_proj = 
nn.Linear(self.intermediate_size, config.hidden_size, bias=False)\n\n    def forward(self, x: Tensor):\n        concat_x1_x2 = self.gate_up_proj(x)\n        x1, x2 = op.split(concat_x1_x2, 2, axis=-1)\n        return self.down_proj(op.silu(x1) * x2)\n\n\nclass StableLmDecoderLayer(nn.Module):\n    def __init__(self, config: StableLmConfig):\n        norm_eps = config.layer_norm_eps\n        self.self_attn = StableLmAttention(config)\n        self.mlp = StableLmMLP(config)\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            i = self.mlp.intermediate_size\n            _set(\n                self.self_attn.qkv_proj.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.use_qkv_bias:\n                _set(\n                    self.self_attn.qkv_proj.bias,\n                    tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n            _set(\n                self.mlp.gate_up_proj.weight,\n                tp.ShardSingleDim(\"_shard_mlp_up\", segs=[i, i], dim=0),\n            )\n            _set(self.mlp.down_proj.weight, tp.ShardSingleDim(\"_shard_mlp_down\", dim=1))\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n    
    hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass StableLmModel(nn.Module):\n    def __init__(self, config: StableLmConfig):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [StableLmDecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass StableLmForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: StableLmConfig):\n        self.model = StableLmModel(config)\n        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.dtype = \"float32\"\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.vocab_size = config.vocab_size\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        
self.partial_rotary_factor = config.partial_rotary_factor\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last token\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        
logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads // self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n            rotary_dim=int(self.head_dim * self.partial_rotary_factor),\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n 
               \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n            
        \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/starcoder2/__init__.py",
    "content": ""
  },
  {
    "path": "python/mlc_llm/model/starcoder2/starcoder2_loader.py",
    "content": "\"\"\"\nThis file specifies how MLC's Starcoder2 parameters map from other formats, for example HuggingFace\nPyTorch, HuggingFace safetensors.\n\"\"\"\n\nimport functools\n\nimport numpy as np\n\nfrom mlc_llm.loader import ExternMapping\nfrom mlc_llm.quantization import Quantization\n\nfrom .starcoder2_model import Starcoder2Config, Starcoder2ForCausalLM\n\n\ndef huggingface(model_config: Starcoder2Config, quantization: Quantization) -> ExternMapping:\n    \"\"\"Returns a parameter mapping that maps from the names of MLC LLM parameters to\n    the names of HuggingFace PyTorch parameters.\n\n    Parameters\n    ----------\n    model_config : Starcoder2Config\n        The configuration of the Starcoder2 model.\n\n    quantization : Quantization\n        The quantization configuration.\n\n    Returns\n    -------\n    param_map : ExternMapping\n        The parameter mapping from MLC to HuggingFace PyTorch.\n    \"\"\"\n    model = Starcoder2ForCausalLM(model_config)\n    if quantization is not None:\n        model.to(quantization.model_dtype)\n    _, _named_params, _ = model.export_tvm(  # type: ignore[misc]\n        spec=model.get_default_spec(),\n        allow_extern=True,\n    )\n    named_parameters = dict(_named_params)\n\n    mapping = ExternMapping()\n\n    mlc_name = \"lm_head.weight\"\n    mlc_param = named_parameters[mlc_name]\n    mapping.add_mapping(\n        mlc_name,\n        [\"model.embed_tokens.weight\"],\n        functools.partial(\n            lambda x, dtype: x.astype(dtype),\n            dtype=mlc_param.dtype,\n        ),\n    )\n\n    for i in range(model_config.num_hidden_layers):\n        # Add QKV in self attention\n        attn = f\"model.layers.{i}.self_attn\"\n        mlc_name = f\"{attn}.wqkv_pack.weight\"\n        mlc_param = named_parameters[mlc_name]\n        mapping.add_mapping(\n            mlc_name,\n            [\n                f\"{attn}.q_proj.weight\",\n                f\"{attn}.k_proj.weight\",\n                
f\"{attn}.v_proj.weight\",\n            ],\n            functools.partial(\n                lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                dtype=mlc_param.dtype,\n            ),\n        )\n        mlc_name = f\"{attn}.wqkv_pack.bias\"\n        if mlc_name in named_parameters:\n            mlc_param = named_parameters[mlc_name]\n            mapping.add_mapping(\n                mlc_name,\n                [\n                    f\"{attn}.q_proj.bias\",\n                    f\"{attn}.k_proj.bias\",\n                    f\"{attn}.v_proj.bias\",\n                ],\n                functools.partial(\n                    lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n        # Add gates in MLP\n\n    for mlc_name, mlc_param in named_parameters.items():\n        if mlc_name not in mapping.param_map:\n            mapping.add_mapping(\n                mlc_name,\n                [mlc_name],\n                functools.partial(\n                    lambda x, dtype: x.astype(dtype),\n                    dtype=mlc_param.dtype,\n                ),\n            )\n    return mapping\n"
  },
  {
    "path": "python/mlc_llm/model/starcoder2/starcoder2_model.py",
    "content": "\"\"\"\nImplementation for Starcoder2 architecture.\n\"\"\"\n\nimport dataclasses\nfrom typing import Any, Dict, Optional\n\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.nn import PagedKVCache, RopeMode\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\nfrom mlc_llm.support.config import ConfigBase\nfrom mlc_llm.support.style import bold\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass Starcoder2Config(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration of the Starcoder2 model.\"\"\"\n\n    vocab_size: int\n    hidden_size: int\n    num_hidden_layers: int\n    num_attention_heads: int\n    num_key_value_heads: int\n    hidden_act: str\n    norm_epsilon: float\n    intermediate_size: int\n    rope_theta: int\n    use_bias: bool\n    use_cache: bool\n    bos_token_id: int\n    eos_token_id: int\n    context_window_size: int = 0\n    prefill_chunk_size: int = 0\n    tensor_parallel_shards: int = 1\n    max_batch_size: int = 1\n    head_dim: int = 0\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n    def __post_init__(self):\n        if self.context_window_size == 0:\n            for name in [\"max_position_embeddings\", \"max_sequence_length\"]:\n                if name in self.kwargs:\n                    self.context_window_size = self.kwargs.pop(name)\n                    logger.info(\n                        \"%s not found in config.json. 
Falling back to %s (%d)\",\n                        bold(\"context_window_size\"),\n                        bold(name),\n                        self.context_window_size,\n                    )\n                    break\n            else:\n                raise ValueError(\n                    \"Unable to determine the maximum sequence length, because none of \"\n                    \"`context_window_size`, `max_position_embeddings` or `max_sequence_length` is \"\n                    \"provided in `config.json`.\"\n                )\n        if self.head_dim == 0:\n            self.head_dim = self.hidden_size // self.num_attention_heads\n        assert self.head_dim * self.num_attention_heads == self.hidden_size\n        if self.prefill_chunk_size == 0:\n            logger.info(\n                \"%s defaults to %d\",\n                bold(\"prefill_chunk_size\"),\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n        elif self.prefill_chunk_size > self.context_window_size:\n            logger.info(\n                \"Overriding %s from %d to %d\",\n                bold(\"prefill_chunk_size\"),\n                self.prefill_chunk_size,\n                min(self.context_window_size, 8192),\n            )\n            self.prefill_chunk_size = min(self.context_window_size, 8192)\n\n\n# pylint: disable=invalid-name,missing-docstring\n\n\nclass Starcoder2Attention(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Starcoder2Config):\n        super().__init__()  # Make sure to call the parent class constructor\n        self.hidden_size = config.hidden_size\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        if config.num_attention_heads % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split 
{config.num_attention_heads} attention heads \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n\n        self.num_heads = config.num_attention_heads // self.tensor_parallel_shards\n        self.head_dim = config.head_dim\n        self.num_key_value_heads = config.num_key_value_heads // self.tensor_parallel_shards\n        self.num_key_value_groups = self.num_heads // self.num_key_value_heads\n        self.max_position_embeddings = config.context_window_size\n        self.use_bias = config.use_bias\n\n        self.wqkv_pack = nn.Linear(\n            in_features=self.hidden_size,\n            out_features=(self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,\n            bias=self.use_bias,\n        )\n        self.o_proj = nn.Linear(\n            self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias\n        )\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        d, h_q, h_kv = self.head_dim, self.num_heads, self.num_key_value_heads\n        b, s, _ = hidden_states.shape\n        qkv = self.wqkv_pack(hidden_states)\n        qkv = op.reshape(qkv, (b, s, h_q + h_kv + h_kv, d))\n        output = op.reshape(\n            paged_kv_cache.attention_with_fused_qkv(\n                layer_id, qkv, self.num_heads, sm_scale=self.head_dim**-0.5\n            ),\n            (b, s, h_q * d),\n        )\n        attn_output = self.o_proj(output)\n        return attn_output\n\n\nclass Starcoder2MLP(nn.Module):\n    def __init__(self, config: Starcoder2Config):\n        if config.intermediate_size % config.tensor_parallel_shards != 0:\n            raise ValueError(\n                f\"Cannot split MLP intermediate size {config.intermediate_size} \"\n                f\"evenly to {config.tensor_parallel_shards} GPUs.\"\n            )\n        self.intermediate_size = config.intermediate_size // config.tensor_parallel_shards\n        embed_dim = config.hidden_size\n\n       
 self.c_fc = nn.Linear(\n            in_features=embed_dim,\n            out_features=self.intermediate_size,\n            bias=config.use_bias,\n        )\n        self.c_proj = nn.Linear(self.intermediate_size, embed_dim, bias=config.use_bias)\n\n    def forward(self, hidden_states: Tensor):\n        hidden_states = self.c_fc(hidden_states)\n        hidden_states = op.gelu(hidden_states, approximate=\"tanh\")\n        hidden_states = self.c_proj(hidden_states)\n        return hidden_states\n\n\nclass Starcoder2DecoderLayer(nn.Module):\n    def __init__(self, config: Starcoder2Config):\n        self.hidden_size = config.hidden_size\n        self.self_attn = Starcoder2Attention(config)\n        self.mlp = Starcoder2MLP(config)\n        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)\n        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)\n\n        def _set_tp():\n            def _set(layer, hint):\n                layer.attrs[\"shard_strategy\"] = hint\n\n            hd = config.head_dim\n            q = self.self_attn.num_heads * hd\n            k = self.self_attn.num_key_value_heads * hd\n            v = self.self_attn.num_key_value_heads * hd\n            _set(\n                self.self_attn.wqkv_pack.weight,\n                tp.ShardSingleDim(\"_shard_qkv_weight\", dim=0, segs=[q, k, v]),\n            )\n            if config.use_bias:\n                _set(\n                    self.self_attn.wqkv_pack.bias,\n                    tp.ShardSingleDim(\"_shard_qkv_bias\", dim=0, segs=[q, k, v]),\n                )\n\n            _set(self.self_attn.o_proj.weight, tp.ShardSingleDim(\"_shard_o\", dim=1))\n\n            _set(\n                self.mlp.c_fc.weight,\n                tp.ShardSingleDim(\"_shard_c_fc_weight\", dim=0),\n            )\n            if config.use_bias:\n                _set(self.mlp.c_fc.bias, tp.ShardSingleDim(\"_shard_c_fc_bias\", dim=0))\n\n            
_set(self.mlp.c_proj.weight, tp.ShardSingleDim(\"_shard_mlp_c_proj\", dim=1))\n\n            if config.use_bias:\n                _set(\n                    self.mlp.c_proj.bias,\n                    tp.ShardSingleDim(\"_shard_mlp_c_proj_bias\", dim=0),\n                )\n\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        _set_tp()\n\n    def forward(self, hidden_states: Tensor, paged_kv_cache: PagedKVCache, layer_id: int):\n        out = self.self_attn(self.input_layernorm(hidden_states), paged_kv_cache, layer_id)\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        out = self.mlp(self.post_attention_layernorm(hidden_states))\n        hidden_states = self._apply_residual(out, residual=hidden_states)\n        return hidden_states\n\n    def _apply_residual(self, out, residual):\n        if self.tensor_parallel_shards > 1:\n            return op.ccl_allreduce(out, \"sum\") + residual\n        return out + residual\n\n\nclass Starcoder2Model(nn.Module):\n    def __init__(self, config: Starcoder2Config):\n        assert config.hidden_size % config.num_attention_heads == 0\n        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)\n        self.layers = nn.ModuleList(\n            [Starcoder2DecoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n        self.norm = nn.LayerNorm(config.hidden_size, config.norm_epsilon)\n\n    def forward(self, inputs: Tensor, paged_kv_cache: PagedKVCache):\n        hidden_states = inputs\n        for layer_id, layer in enumerate(self.layers):\n            hidden_states = layer(hidden_states, paged_kv_cache, layer_id)\n        hidden_states = self.norm(hidden_states)\n        return hidden_states\n\n\nclass Starcoder2ForCausalLM(nn.Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: Starcoder2Config):\n        self.model = Starcoder2Model(config)\n        self.lm_head = nn.Linear(config.hidden_size, 
config.vocab_size, bias=False)\n        self.vocab_size = config.vocab_size\n        self.num_hidden_layers = config.num_hidden_layers\n        self.hidden_size = config.hidden_size\n        self.num_attention_heads = config.num_attention_heads\n        self.num_key_value_heads = config.num_key_value_heads\n        self.head_dim = config.head_dim\n        self.rope_theta = config.rope_theta\n        self.tensor_parallel_shards = config.tensor_parallel_shards\n        self.dtype = \"float32\"\n\n    def to(self, dtype: Optional[str] = None):\n        super().to(dtype=dtype)\n        if dtype is not None:\n            self.dtype = dtype\n\n    def batch_forward(\n        self,\n        input_embeds: Tensor,\n        paged_kv_cache: PagedKVCache,\n        logit_positions: Optional[Tensor] = None,\n    ):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embeds, paged_kv_cache)\n        if logit_positions is not None:\n            hidden_states = op.take(hidden_states, logit_positions, axis=1)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits\n\n    def embed(self, input_ids: Tensor):\n        if self.tensor_parallel_shards > 1:\n            input_ids = op.ccl_broadcast_from_worker0(input_ids)\n        return self.model.embed_tokens(input_ids)\n\n    def prefill(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        def _index(x: te.Tensor):  # x[:, -1:, :], i.e. keep only the last position\n            b, s, d = x.shape\n            return te.compute((b, 1, d), lambda i, _, k: x[i, s - 1, k], name=\"index\")\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        hidden_states = op.tensor_expr_op(_index, name_hint=\"index\", args=[hidden_states])\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = 
logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def decode(self, input_embed: Tensor, paged_kv_cache: PagedKVCache):\n        op_ext.configure()\n\n        hidden_states = self.model(input_embed, paged_kv_cache)\n        logits = self.lm_head(hidden_states)\n        if logits.dtype != \"float32\":\n            logits = logits.astype(\"float32\")\n        return logits, paged_kv_cache\n\n    def batch_prefill(\n        self,\n        input_embeds: Tensor,\n        logit_positions: Tensor,\n        paged_kv_cache: PagedKVCache,\n    ):\n        if self.tensor_parallel_shards > 1:\n            logit_positions = op.ccl_broadcast_from_worker0(logit_positions)\n        logits = self.batch_forward(input_embeds, paged_kv_cache, logit_positions)\n        return logits, paged_kv_cache\n\n    def batch_decode(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def batch_verify(self, input_embeds: Tensor, paged_kv_cache: PagedKVCache):\n        logits = self.batch_forward(input_embeds, paged_kv_cache)\n        return logits, paged_kv_cache\n\n    def create_paged_kv_cache(  # pylint: disable=too-many-arguments\n        self,\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n    ) -> PagedKVCache:\n        return PagedKVCache.create_generic(\n            attn_kind=\"mha\",\n            max_batch_size=max_batch_size,\n            max_total_seq_len=max_total_seq_len,\n            prefill_chunk_size=prefill_chunk_size,\n            page_size=page_size,\n            support_sliding_window=support_sliding_window,\n            num_hidden_layers=self.num_hidden_layers,\n            num_attention_heads=self.num_attention_heads // self.tensor_parallel_shards,\n            num_key_value_heads=self.num_key_value_heads 
// self.tensor_parallel_shards,\n            qk_head_dim=self.head_dim,\n            v_head_dim=self.head_dim,\n            rope_mode=RopeMode.NORMAL,\n            rope_scale=1,\n            rope_theta=self.rope_theta,\n            dtype=self.dtype,\n        )\n\n    def get_default_spec(self):\n        mod_spec = {\n            \"embed\": {\n                \"input_ids\": nn.spec.Tensor([\"seq_len\"], \"int32\"),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"prefill\": {\n                \"input_embed\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"decode\": {\n                \"input_embed\": nn.spec.Tensor([1, 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_prefill\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"logit_positions\": nn.spec.Tensor([\"batch_size\"], \"int32\"),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_decode\": {\n                \"input_embeds\": nn.spec.Tensor([\"batch_size\", 1, self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n           
         \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"batch_verify\": {\n                \"input_embeds\": nn.spec.Tensor([1, \"seq_len\", self.hidden_size], self.dtype),\n                \"paged_kv_cache\": nn.spec.Object(object_type=PagedKVCache),\n                \"$\": {\n                    \"param_mode\": \"packed\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n                \"$\": {\n                    \"param_mode\": \"none\",\n                    \"effect_mode\": \"none\",\n                },\n            },\n        }\n        return nn.spec.ModuleSpec.from_raw(mod_spec, self)\n"
  },
  {
    "path": "python/mlc_llm/model/vision/__init__.py",
    "content": "\"\"\"Vision modules (CLIP vision encoder and image processor) used by multimodal models.\"\"\"\n\nfrom .clip_vision import CLIPVisionConfig, CLIPVisionModel\nfrom .image_processing import ImageProcessor\n"
  },
  {
    "path": "python/mlc_llm/model/vision/clip_vision.py",
    "content": "\"\"\"\nImplements the CLIP Vision Encoder.\n\"\"\"\n\nimport dataclasses\nimport logging\nfrom typing import Any, Dict, Tuple\n\nfrom tvm import relax\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Module, Tensor\nfrom tvm.relax.frontend.nn.modules import Conv2D\nfrom tvm.relax.frontend.nn.op import (\n    add,\n    broadcast_to,\n    concat,\n    permute_dims,\n    reshape,\n    wrap_nested,\n)\nfrom tvm.relax.op import arange\n\nfrom mlc_llm import op as op_ext\nfrom mlc_llm.support.config import ConfigBase\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclasses.dataclass\nclass CLIPVisionConfig(ConfigBase):  # pylint: disable=too-many-instance-attributes\n    \"\"\"\n    Config for the vision encoder\n    \"\"\"\n\n    hidden_size: int\n    image_size: int\n    intermediate_size: int\n    num_attention_heads: int\n    num_hidden_layers: int\n    patch_size: int\n    projection_dim: int\n    vocab_size: int\n    num_channels: int = 3\n    layer_norm_eps: float = 1e-06\n    kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)\n\n\n# pylint: disable=invalid-name,missing-docstring\nclass CLIPVisionEmbeddings(Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        self.config = config\n        self.embed_dim = config.hidden_size\n        self.image_size = config.image_size\n        self.patch_size = config.patch_size\n        self.class_embedding = nn.Parameter((self.embed_dim,))\n        self.patch_embedding = Conv2D(\n            in_channels=config.num_channels,\n            out_channels=self.embed_dim,\n            kernel_size=self.patch_size,\n            stride=self.patch_size,\n            bias=False,\n        )\n\n        self.num_patches = (self.image_size // self.patch_size) ** 2\n        self.num_positions = self.num_patches + 1\n        self.position_embedding = nn.Embedding(num=self.num_positions, 
dim=self.embed_dim)\n\n    def forward(self, pixel_values: Tensor) -> Tensor:\n        batch_size = pixel_values.shape[0]\n        patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]\n        patch_embeds = reshape(patch_embeds, shape=(batch_size, self.embed_dim, -1))\n        patch_embeds = permute_dims(\n            patch_embeds, axes=(0, 2, 1)\n        )  # shape = [batch,grid*grid,embed_dim]\n        class_embeds = broadcast_to(\n            self.class_embedding, shape=(batch_size, 1, self.embed_dim)\n        )  # shape of (batch,1,embed_dim)\n        embeddings = concat([class_embeds, patch_embeds], dim=1)\n\n        posi_ids = reshape(\n            wrap_nested(arange(0, self.num_positions, dtype=\"int32\"), name=\"arange\"),\n            shape=(1, -1),\n        )\n        batch_position_embedding = broadcast_to(\n            self.position_embedding(posi_ids),\n            shape=(batch_size, self.num_positions, self.embed_dim),\n        )\n        embeddings = add(embeddings, batch_position_embedding)\n        return embeddings\n\n\n# pylint: disable=missing-docstring\ndef sigmoid(x: Tensor, name: str = \"sigmoid\") -> Tensor:\n    \"\"\"Sigmoid of a Tensor\n\n    Parameters\n    ----------\n    x : Tensor\n        Input tensor to expand.\n    name : str\n        Name hint for this operator.\n\n    Returns\n    -------\n    result : Tensor\n        Sigmoid result.\n    \"\"\"\n    return wrap_nested(relax.op.sigmoid(x._expr), name)  # pylint: disable=protected-access\n\n\nclass QuickGELU(Module):\n    def forward(self, input_tensor: Tensor) -> Tensor:\n        return input_tensor * sigmoid(input_tensor * 1.702)\n\n\nclass CLIPMLP(Module):\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        self.activation_fn = QuickGELU()\n        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)\n        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)\n\n    def 
forward(self, hidden_states: Tensor) -> Tensor:\n        hidden_states = self.fc1(hidden_states)\n        hidden_states = self.activation_fn(hidden_states)\n        hidden_states = self.fc2(hidden_states)\n        return hidden_states\n\n\nclass CLIPAttention(Module):  # pylint: disable=too-many-instance-attributes\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.num_heads = config.num_attention_heads\n        self.head_dim = self.embed_dim // self.num_heads\n        if (self.head_dim * self.num_heads) != self.embed_dim:\n            raise ValueError(\n                f\"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}\"\n                f\" and `num_heads`: {self.num_heads}).\"\n            )\n        self.scale = self.head_dim**-0.5\n        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)\n        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)\n        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)\n        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)\n\n    def forward(\n        self,\n        hidden_states: Tensor,\n    ) -> Tensor:\n        d, h = self.head_dim, self.num_heads\n        b, s, _ = hidden_states.shape  # batch_size, seq_len, embed_dim\n\n        q = self.q_proj(hidden_states).reshape(b, s, h, d)\n        k = self.k_proj(hidden_states).reshape(b, s, h, d)\n        v = self.v_proj(hidden_states).reshape(b, s, h, d)\n\n        attn_output = op_ext.attention(q, k, v, None)\n        attn_output = self.out_proj(attn_output)\n        return attn_output\n\n\nclass CLIPEncoderLayer(Module):\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        self.embed_dim = config.hidden_size\n        self.self_attn = CLIPAttention(config)\n        self.layer_norm1 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps)\n        self.mlp = CLIPMLP(config)\n        
self.layer_norm2 = nn.LayerNorm(normalized_shape=self.embed_dim, eps=config.layer_norm_eps)\n\n    def forward(self, hidden_states: Tensor) -> Tensor:\n        residual = hidden_states\n        hidden_states = self.layer_norm1(hidden_states)\n        hidden_states = self.self_attn(hidden_states=hidden_states)\n        hidden_states = residual + hidden_states\n        residual = hidden_states\n        hidden_states = self.layer_norm2(hidden_states)\n        hidden_states = self.mlp(hidden_states)\n        hidden_states = residual + hidden_states\n\n        outputs = (hidden_states,)\n        return outputs\n\n\nclass CLIPEncoder(Module):\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        self.layers = nn.ModuleList(\n            [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]\n        )\n\n    def forward(self, inputs_embeds: Tensor) -> Tensor:\n        hidden_states = inputs_embeds\n        encoder_states: Tuple[Any, ...] = ()\n        for _, encoder_layer in enumerate(self.layers):\n            encoder_states = encoder_states + (hidden_states,)\n            layer_outputs = encoder_layer(hidden_states)\n            hidden_states = layer_outputs[0]\n        encoder_states = encoder_states + (hidden_states,)\n        return encoder_states\n\n\nclass CLIPVisionTransformer(Module):\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        embed_dim = config.hidden_size\n        self.embeddings = CLIPVisionEmbeddings(config)\n        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n        self.encoder = CLIPEncoder(config)\n        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)\n\n    def forward(self, pixel_values: Tensor) -> Tensor:\n        hidden_states = self.embeddings(pixel_values)\n        hidden_states = self.pre_layrnorm(hidden_states)\n        encoder_outputs = self.encoder(inputs_embeds=hidden_states)\n        return 
encoder_outputs\n\n\nclass CLIPVisionModel(Module):\n    no_quantization: bool = True\n\n    def __init__(self, config: CLIPVisionConfig):\n        super().__init__()\n        self.vision_model = CLIPVisionTransformer(config)\n\n    def forward(self, pixel_values: Tensor) -> Tensor:\n        return self.vision_model(pixel_values)[-2]\n"
  },
  {
    "path": "python/mlc_llm/model/vision/image_processing.py",
    "content": "\"\"\"\nImplements the CLIP Image processor.\n\"\"\"\n\nfrom tvm import s_tir, tir\nfrom tvm.relax.frontend.nn import Module, Tensor, op\nfrom tvm.script import tir as T\n\n\ndef _var(dtype, size=1):\n    return T.sblock_alloc_buffer((size,), dtype, scope=\"local\")\n\n\n# pylint: disable=invalid-name,missing-docstring,no-else-return,too-many-locals,useless-parent-delegation\nclass ImageProcessor(Module):\n    def __init__(self):\n        super().__init__()\n\n    # pylint: disable=dangerous-default-value\n    def apply_schedule(self, sch, block, bdx=32, tile=[32, 32]):\n        loop_x, loop_y = sch.get_loops(block)[-2:]\n        xo, xi = sch.split(loop_x, factors=[tile[0], None])\n        yo, yi = sch.split(loop_y, factors=[tile[1], None])\n        sch.reorder(xo, yo, xi, yi)\n        t = sch.fuse(xo, yo)\n        ty, tx = sch.split(t, factors=[None, bdx])\n        sch.bind(ty, \"threadIdx.y\")\n        sch.bind(tx, \"threadIdx.x\")\n\n    def resize(self, image: Tensor, params):  # image layout:NCHW\n        assert 4 == image.ndim, \"image should be 4D data tensor\"\n        assert 3 == image.shape[1], \"image layout should be NCHW\"\n\n        def get_output_image_size(image: Tensor):\n            h = image.shape[2]\n            w = image.shape[3]\n\n            if \"height\" in params and \"width\" in params:\n                return (params[\"height\"], params[\"width\"])\n            elif \"shortest_edge\" in params:\n                short = tir.Select(w < h, w, h)\n                long = tir.Select(w > h, w, h)\n                requested_new_short = params[\"shortest_edge\"]\n                new_short, new_long = (\n                    tir.generic.cast(requested_new_short, \"int64\"),\n                    tir.generic.cast(\n                        requested_new_short\n                        * tir.div(\n                            tir.generic.cast(long, \"float32\"),\n                            tir.generic.cast(short, \"float32\"),\n          
              ),\n                        \"int64\",\n                    ),\n                )\n                ret_h = tir.Select(w <= h, new_long, new_short)\n                ret_w = tir.Select(w <= h, new_short, new_long)\n                return (ret_h, ret_w)\n            elif \"hd_transform\" in params:\n                hd_num = 4 if \"hd_num\" not in params else params[\"hd_num\"]\n                pad_num = 336 if \"pad_num\" not in params else params[\"pad_num\"]\n                ratio = tir.Select(\n                    w > h,\n                    tir.div(tir.generic.cast(w, \"float32\"), tir.generic.cast(h, \"float32\")),\n                    tir.div(tir.generic.cast(h, \"float32\"), tir.generic.cast(w, \"float32\")),\n                )\n\n                scale = tir.ceil(tir.sqrt(tir.generic.cast(hd_num, \"float32\") * ratio))\n\n                scale = tir.Select(\n                    (scale * tir.ceil(tir.div(scale, ratio))) > hd_num,\n                    scale - 1,\n                    scale,\n                )\n                scale = tir.generic.cast(scale, \"int64\")\n\n                new_w = tir.Select(\n                    w >= h,\n                    scale * pad_num,\n                    tir.generic.cast(tir.div(scale * pad_num, ratio), \"int64\"),\n                )\n                new_h = tir.Select(\n                    w >= h,\n                    tir.generic.cast(tir.div(new_w, ratio), \"int64\"),\n                    scale * pad_num,\n                )\n                return (new_h, new_w)\n            else:\n                assert False, \"not supported resize parameter\"\n\n        new_h, new_w = get_output_image_size(image)\n        out = op.interpolate(image, (new_h, new_w), data_layout=\"NCHW\", mode=\"linear\")\n        return out\n\n    # pylint: disable=too-many-arguments,too-many-locals\n    def crop(self, image: Tensor, crop_size):\n        assert 4 == image.ndim, \"image should be 4D data tensor\"\n        assert 3 == 
image.shape[1], \"image layout should be NCHW\"\n\n        def create_crop_func(dtype):  # , top, bottom, left, right):\n            @T.prim_func\n            def crop_func(\n                image: T.handle,\n                out: T.handle,\n                top: T.int64(),\n                bottom: T.int64(),\n                left: T.int64(),\n                right: T.int64(),\n            ):\n                T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n                n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()\n                image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)\n                out_buf = T.match_buffer(out, (n, c, bottom - top, right - left), dtype=dtype)\n                out_h = bottom - top\n                out_w = right - left\n                for n_idx in T.thread_binding(n, thread=\"blockIdx.x\"):\n                    for c_idx in T.thread_binding(c, thread=\"blockIdx.y\"):\n                        for h_idx, w_idx in T.grid(out_h, out_w):\n                            with T.sblock(\"crop\"):\n                                if (h_idx + T.int64(top)) < h and (w_idx + T.int64(left)) < w:\n                                    T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])\n                                    T.reads(image_buf[n_idx, c_idx, h_idx + top, w_idx + left])\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = image_buf[\n                                        n_idx, c_idx, h_idx + top, w_idx + left\n                                    ]\n\n            sch = s_tir.Schedule(crop_func)\n            self.apply_schedule(sch, sch.get_sblock(\"crop\"))\n            return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n        n, c, orig_height, orig_width = image.shape\n        crop_height = crop_size[\"height\"]\n        crop_width = crop_size[\"width\"]\n\n        top = (orig_height - crop_height) // 2\n        bottom = orig_height - top\n\n        left 
= (orig_width - crop_width) // 2\n        right = orig_width - left\n\n        out = op.tensor_ir_op(\n            create_crop_func(image.dtype),\n            \"crop\",\n            [image, top, bottom, left, right],\n            [Tensor.placeholder([n, c, crop_height, crop_width], image.dtype)],\n        )\n        return out\n\n    def rescale(self, image: Tensor, rescale_factor=1 / 255.0, o_dtype=\"float32\"):\n        assert 4 == image.ndim, \"image should be 4D data tensor\"\n        assert 3 == image.shape[1], \"image layout should be NCHW\"\n\n        def create_rescale_func(rescale_factor, dtype, o_dtype):\n            @T.prim_func\n            def rescale_func(image: T.handle, out: T.handle):\n                T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n                n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()\n                image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)\n                out_buf = T.match_buffer(out, (n, c, h, w), dtype=o_dtype)\n\n                for n_idx in T.thread_binding(n, thread=\"blockIdx.x\"):\n                    for c_idx in T.thread_binding(c, thread=\"blockIdx.y\"):\n                        for h_idx, w_idx in T.grid(h, w):\n                            with T.sblock(\"rescale\"):\n                                T.reads(image_buf[n_idx, c_idx, h_idx, w_idx])\n                                T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])\n                                if h_idx < h and w_idx < w:\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = (\n                                        T.cast(\n                                            image_buf[n_idx, c_idx, h_idx, w_idx],\n                                            o_dtype,\n                                        )\n                                        * rescale_factor\n                                    )\n\n            sch = s_tir.Schedule(rescale_func)\n            
self.apply_schedule(sch, sch.get_sblock(\"rescale\"))\n            return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n        out = op.tensor_ir_op(\n            create_rescale_func(rescale_factor, image.dtype, o_dtype),\n            \"rescale\",\n            [image],\n            [Tensor.placeholder(image.shape, o_dtype)],\n        )\n        return out\n\n    def normalize(self, image: Tensor, o_dtype=\"float32\"):\n        assert 4 == image.ndim, \"image should be 4D data tensor\"\n        assert 3 == image.shape[1], \"image layout should be NCHW\"\n\n        def create_normalize_func(dtype, o_dtype):\n            @T.prim_func\n            def normalize_func(image: T.handle, out: T.handle):\n                n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()\n                image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)\n                out_buf = T.match_buffer(out, (n, c, h, w), dtype=o_dtype)\n                mean = _var(o_dtype, 3)\n                stddev = _var(o_dtype, 3)\n\n                for n_idx in T.thread_binding(n, thread=\"blockIdx.x\"):\n                    for c_idx in T.thread_binding(c, thread=\"blockIdx.y\"):\n                        for h_idx, w_idx in T.grid(h, w):\n                            with T.sblock(\"normalize\"):\n                                T.reads(\n                                    image_buf[n_idx, c_idx, h_idx, w_idx],\n                                    mean[c_idx],\n                                    stddev[c_idx],\n                                )\n                                T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])\n                                with T.init():\n                                    mean[0] = 0.48145466\n                                    stddev[0] = 0.26862954\n                                    mean[1] = 0.4578275\n                                    stddev[1] = 0.26130258\n                                    mean[2] = 0.40821073\n                          
          stddev[2] = 0.27577711\n                                if h_idx < h and w_idx < w:\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = (\n                                        T.cast(\n                                            image_buf[n_idx, c_idx, h_idx, w_idx],\n                                            o_dtype,\n                                        )\n                                        - mean[c_idx]\n                                    ) / stddev[c_idx]\n\n            sch = s_tir.Schedule(normalize_func)\n            self.apply_schedule(sch, sch.get_sblock(\"normalize\"))\n            return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n        out = op.tensor_ir_op(\n            create_normalize_func(image.dtype, o_dtype),\n            \"normalize\",\n            [image],\n            [Tensor.placeholder(image.shape, o_dtype)],\n        )\n        return out\n\n    def pad(self, image: Tensor, dtype=\"uint8\"):\n        assert 4 == image.ndim, \"image should be 4D data tensor\"\n        assert 3 == image.shape[1], \"image layout should be NCHW\"\n\n        def create_pad_func(l, r, fill=255):\n            @T.prim_func\n            def pad_func(image: T.handle, out: T.handle, t: T.int64(), b: T.int64()):\n                T.func_attr({\"op_pattern\": 8, \"tir.noalias\": True, \"tir.is_scheduled\": 1})\n                n, c, h, w = T.int64(), T.int64(), T.int64(), T.int64()\n                image_buf = T.match_buffer(image, (n, c, h, w), dtype=dtype)\n                out_buf = T.match_buffer(out, (n, c, h + t + b, w + l + r), dtype=dtype)\n                out_h = h + t + b\n                out_w = w + l + r\n\n                for n_idx in T.thread_binding(n, thread=\"blockIdx.x\"):\n                    for c_idx in T.thread_binding(c, thread=\"blockIdx.y\"):\n                        for h_idx, w_idx in T.grid(out_h, out_w):\n                            with T.sblock(\"pad\"):\n                             
   T.reads(image_buf[n_idx, c_idx, h_idx, w_idx])\n                                T.writes(out_buf[n_idx, c_idx, h_idx, w_idx])\n                                # Output rows [t, t + h) and columns [l, l + w) hold the image; the rest is padding.\n                                if h_idx < t or h_idx >= h + t or w_idx < l or w_idx >= w + l:\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = fill\n                                else:\n                                    out_buf[n_idx, c_idx, h_idx, w_idx] = image_buf[\n                                        n_idx, c_idx, h_idx - t, w_idx - l\n                                    ]\n\n            sch = s_tir.Schedule(pad_func)\n            self.apply_schedule(sch, sch.get_sblock(\"pad\"))\n            return sch.mod[\"main\"].with_attr(\"tir.is_scheduled\", 1)\n\n        h = image.shape[2]\n        tar = tir.truncdiv(h + 335, 336) * 336\n        t = tir.div(tar - h, 2)\n        b = tar - h - t\n        l = 0\n        r = 0\n\n        n, c, h, w = image.shape\n        out = op.tensor_ir_op(\n            create_pad_func(l, r),\n            \"pad\",\n            [image, t, b],\n            [Tensor.placeholder((n, c, tar, w), image.dtype)],\n        )\n        return out\n\n    def preprocess(self, pixel_values):\n        return pixel_values\n"
  },
  {
    "path": "python/mlc_llm/nn/__init__.py",
    "content": "\"\"\"Common `nn.Modules` used to define LLMs in this project.\"\"\"\n\nfrom .expert import MixtralExperts\nfrom .kv_cache import PagedKVCache, RopeMode\n"
  },
  {
    "path": "python/mlc_llm/nn/expert.py",
    "content": "\"\"\"An nn.Module that represents MoE experts\"\"\"\n\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor\n\nfrom mlc_llm.op import cutlass, extern, ft_gemm, moe_matmul\n\n\nclass MixtralExperts(nn.Module):\n    \"\"\"Mixtral experts\"\"\"\n\n    def __init__(self, num_local_experts, in_features, out_features, tensor_parallel_shards=1):\n        self.num_local_experts = num_local_experts\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weight = nn.Parameter((num_local_experts, out_features, in_features))\n        self.dtype = \"float32\"\n        self.tensor_parallel_shards = tensor_parallel_shards\n\n    def forward(self, x: Tensor, indptr: Tensor):  # pylint: disable=invalid-name,missing-docstring\n        assert x.ndim == 2\n        if indptr.ndim == 2:\n            assert indptr.shape[0] == 1\n            return moe_matmul.gemv(x, self.weight, indptr)\n        assert indptr.ndim == 1\n        if extern.get_store().cutlass_group_gemm and self.dtype in [\n            \"float16\",\n            \"bfloat16\",\n        ]:\n            return cutlass.group_gemm(x, self.weight, indptr)\n        if extern.get_store().faster_transformer and self.dtype == \"float16\":\n            return ft_gemm.faster_transformer_moe_gemm(x, self.weight, indptr)\n        return moe_matmul.group_gemm(x, self.weight, indptr)\n"
  },
  {
    "path": "python/mlc_llm/nn/kv_cache.py",
    "content": "\"\"\"Attention KV cache modeling.\"\"\"\n\n# pylint: disable=too-many-statements,too-many-lines,too-many-arguments\nimport json\nfrom typing import Any, Dict, List, Literal, Optional, Union\n\nimport numpy as np\nfrom tvm import relax as rx\nfrom tvm import tir\nfrom tvm.relax.frontend.nn.llm.kv_cache import PagedKVCache as TVMPagedKVCache\nfrom tvm.relax.frontend.nn.llm.kv_cache import RopeMode\n\n\nclass PagedKVCache(TVMPagedKVCache):  # pylint: disable=too-few-public-methods\n    \"\"\"The Paged KV Cache used in LLM batching for efficient attention computation.\"\"\"\n\n    @staticmethod\n    def create_generic(  # pylint: disable=too-many-locals\n        attn_kind: Union[Literal[\"mha\", \"mla\"], List[Literal[\"mha\", \"mla\", \"mha_sliding\"]]],\n        max_batch_size: tir.Var,\n        max_total_seq_len: tir.Var,\n        prefill_chunk_size: tir.Var,\n        page_size: tir.Var,\n        support_sliding_window: tir.Var,\n        num_hidden_layers: int,\n        num_attention_heads: int,\n        num_key_value_heads: int,\n        qk_head_dim: int,\n        v_head_dim: int,\n        rope_mode: RopeMode,\n        rope_scale: int,\n        rope_theta: int,\n        dtype: str,\n        mla_original_qk_head_dim: int = 0,\n        mla_original_v_head_dim: int = 0,\n        rotary_dim: Optional[int] = None,\n        rope_scaling: Optional[Dict[str, Any]] = None,\n        rope_ext_factors: Optional[List[int]] = None,\n        layer_partition: Optional[List[int]] = None,\n        enable_disaggregation: bool = False,\n        name: str = \"paged_kv_cache\",\n    ) -> \"PagedKVCache\":\n        \"\"\"The generic function of creating a multi-head attention PagedKVCache,\n        which will be rewritten by functions in compilation pipeline.\n        \"\"\"\n        if rotary_dim is None:\n            rotary_dim = qk_head_dim\n        if rope_scaling is None:\n            rope_scaling = {}\n        if layer_partition is None:\n            
layer_partition = [0, num_hidden_layers]\n        if isinstance(attn_kind, List):\n            rx_attn_kind = [rx.StringImm(layer_kind) for layer_kind in attn_kind]\n        else:\n            rx_attn_kind = rx.StringImm(attn_kind)\n        return PagedKVCache(\n            _expr=rx.call_pure_packed(\n                \"mlc.create_paged_kv_cache_generic\",\n                rx_attn_kind,\n                rx.ShapeExpr(\n                    [\n                        max_batch_size,\n                        max_total_seq_len,\n                        prefill_chunk_size,\n                        page_size,\n                        support_sliding_window,\n                    ]\n                ),\n                rx.ShapeExpr(layer_partition),\n                rx.PrimValue(num_hidden_layers),\n                rx.PrimValue(num_attention_heads),\n                rx.PrimValue(num_key_value_heads),\n                rx.PrimValue(qk_head_dim),\n                rx.PrimValue(v_head_dim),\n                rx.PrimValue(mla_original_qk_head_dim),\n                rx.PrimValue(mla_original_v_head_dim),\n                rx.PrimValue(rope_mode),\n                rx.PrimValue(rope_scale),\n                rx.PrimValue(rope_theta),\n                rx.StringImm(json.dumps(rope_scaling)),\n                (\n                    rx.const(np.array(rope_ext_factors, \"float32\"))\n                    if rope_ext_factors is not None\n                    else rx.PrimValue(0)\n                    # NOTE: since relax does not have \"Optional\" type, we use PrimValue(0)\n                    # to represent \"undefined\".\n                ),\n                rx.PrimValue(rotary_dim),\n                rx.PrimValue(int(enable_disaggregation)),\n                rx.DataTypeImm(dtype),\n                sinfo_args=rx.ObjectStructInfo(),\n            ),\n            _name=name,\n        )\n"
  },
  {
    "path": "python/mlc_llm/nn/rnn_state.py",
    "content": "\"\"\"RNN State modeling.\"\"\"\n\nfrom typing import Sequence, Union\n\nfrom tvm import relax as rx\nfrom tvm import tir\nfrom tvm.relax.frontend.nn import Object, Tensor\nfrom tvm.script import tir as T\n\n\nclass RNNState(Object):\n    \"\"\"The RNN state used in State Space Models\"\"\"\n\n    @staticmethod\n    def create(\n        max_batch_size: tir.Var,\n        num_hidden_layers: int,\n        max_history: int,\n        init_values: Sequence[Tensor],\n        name: str = \"rnn_state\",\n    ) -> \"RNNState\":\n        \"\"\"Create an RNN state object.\n\n        Parameters\n        ----------\n        max_batch_size : tir.Var\n            The maximum batch size.\n        num_hidden_layers : int\n            The number of hidden layers.\n        max_history : int\n            The maximum history length.\n        init_values : Sequence[Tensor]\n            The initial values of the RNN state.\n        name : str\n            The name of the RNN state object.\n        \"\"\"\n\n        bb = rx.BlockBuilder.current()\n        state_infos = [(v.shape, v.dtype) for v in init_values]\n\n        f_gets = [\n            bb.add_func(\n                RNNState.create_get_func(shape, dtype, max_batch_size, max_history, id),\n                f\"rnn_state_get_{id}\",\n            )\n            for id, (shape, dtype) in enumerate(state_infos)\n        ]\n        f_sets = [\n            bb.add_func(\n                RNNState.create_set_func(shape, dtype, max_batch_size, max_history, id),\n                f\"rnn_state_set_{id}\",\n            )\n            for id, (shape, dtype) in enumerate(state_infos)\n        ]\n\n        ret = RNNState(\n            _expr=rx.call_pure_packed(\n                \"vm.builtin.rnn_state_create\",\n                rx.PrimValue(num_hidden_layers),\n                max_batch_size,\n                max_history,\n                f_gets,\n                f_sets,\n                [v._expr for v in init_values],  # pylint: disable=protected-access\n                
sinfo_args=[rx.ObjectStructInfo()],\n            ),\n            _name=name,\n        )\n        return ret\n\n    def get(\n        self,\n        layer_id: int,\n        state_id: int,\n        shape: Sequence[tir.PrimExpr],\n        dtype: str,\n    ) -> Tensor:\n        \"\"\"Get the state of the RNN layer.\n\n        - If there is only one sequence, we can directly use the storage memory,\n        without copying the data.\n        - If there are multiple sequences, we need to copy the data to get a contiguous\n        memory.\n\n        Parameters\n        ----------\n        layer_id : int\n            The layer id.\n        state_id : int\n            The state id.\n        shape : Sequence[tir.PrimExpr]\n            The shape of the state tensor.\n        dtype: str\n            The data type of the state tensor.\n\n        Returns\n        -------\n        Tensor\n            The state tensor, with shape `(batch_size, *state_size)`.\n        \"\"\"\n        bb = rx.BlockBuilder.current()\n\n        return Tensor(\n            _expr=bb.emit(\n                rx.call_dps_packed(\n                    \"vm.builtin.rnn_state_get\",\n                    [self._expr, layer_id, state_id],\n                    out_sinfo=rx.TensorStructInfo(shape, dtype),\n                )\n            )\n        )\n\n    def set(self, layer_id: int, state_id: int, value: Tensor) -> \"RNNState\":\n        \"\"\"Set the state of the RNN layer.\n\n        Parameters\n        ----------\n        layer_id : int\n            The layer id.\n        state_id : int\n            The state id.\n        value : Tensor\n            The state tensor, with shape `(batch_size, *state_size)`.\n        \"\"\"\n        bb = rx.BlockBuilder.current()\n        return RNNState(\n            _expr=bb.emit(\n                rx.call_pure_packed(\n                    \"vm.builtin.rnn_state_set\",\n                    self._expr,\n                    rx.PrimValue(layer_id),\n                    
rx.PrimValue(state_id),\n                    value._expr,  # pylint: disable=protected-access\n                    sinfo_args=[rx.ObjectStructInfo()],\n                )\n            ),\n            _name=\"rnn_state_set\",\n        )\n\n    @staticmethod\n    def create_get_func(\n        shape: Sequence[Union[int, tir.Var]],\n        dtype: str,\n        max_batch_size: Union[int, tir.Var],\n        max_history: Union[int, tir.Var],\n        state_id: int,\n    ) -> tir.PrimFunc:\n        \"\"\"Create the get function with given state shape.\n\n        Parameters\n        ----------\n        shape : Sequence[Union[int, tir.Var]]\n            The shape of the state tensor.\n\n        dtype: str\n            The data type of the state tensor.\n\n        max_batch_size : Union[int, tir.Var]\n            The maximum batch size.\n\n        max_history : Union[int, tir.Var]\n            The maximum history length.\n\n        state_id : int\n            The id of the state, used for naming the function.\n\n        Returns\n        -------\n        tir.PrimFunc\n            The get function.\n        \"\"\"\n\n        def _func_one_dim():\n            @T.prim_func\n            def f(\n                var_storage: T.handle,\n                var_seq_slot_ids: T.handle,\n                var_history_slot_ids: T.handle,\n                var_output: T.handle,\n            ):\n                batch_size = T.int32(is_size_var=True)\n                T.func_attr({\"global_symbol\": f\"rnn_state_get_{state_id}\"})\n\n                storage = T.match_buffer(\n                    var_storage, (max_batch_size, max_history, shape[0]), dtype\n                )\n                seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), \"int32\")\n                history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), \"int32\")\n                output = T.match_buffer(var_output, (batch_size, shape[0]), dtype)\n\n                for i in range(batch_size):\n          
          for s in range(shape[0]):\n                        with T.sblock(\"copy\"):\n                            vi, vs = T.axis.remap(\"SS\", [i, s])\n                            seq_id: T.int32 = seq_slot_ids[vi]\n                            history_id: T.int32 = history_slot_ids[vi]\n                            output[vi, vs] = storage[seq_id, history_id, vs]\n\n            return f\n\n        def _func_high_dim():\n            # Wrap in a function to avoid parsing the following code when len(shape) == 1\n            @T.prim_func\n            def f(\n                var_storage: T.handle,\n                var_seq_slot_ids: T.handle,\n                var_history_slot_ids: T.handle,\n                var_output: T.handle,\n            ):\n                batch_size = T.int32(is_size_var=True)\n                T.func_attr({\"global_symbol\": f\"rnn_state_get_{state_id}\"})\n\n                storage = T.match_buffer(var_storage, (max_batch_size, max_history, *shape), dtype)\n                seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), \"int32\")\n                history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), \"int32\")\n                output = T.match_buffer(var_output, (batch_size, *shape), dtype)\n\n                for i in range(batch_size):\n                    for s in T.grid(*shape):\n                        with T.sblock(\"copy\"):\n                            vi, *vs = T.axis.remap(\"S\" * (len(shape) + 1), [i, *s])\n                            seq_id: T.int32 = seq_slot_ids[vi]\n                            history_id: T.int32 = history_slot_ids[vi]\n                            # The following line is equivalent to:\n                            # `output[vi, *vs] = storage[seq_id, history_id, *vs]`\n                            # However, unpacking operator in subscript requires Python 3.11 or newer\n                            T.buffer_store(\n                                output,\n                              
  T.BufferLoad(storage, [seq_id, history_id, *vs]),\n                                [vi, *vs],\n                            )\n\n            return f\n\n        return _func_one_dim() if len(shape) == 1 else _func_high_dim()\n\n    @staticmethod\n    def create_set_func(\n        shape: Sequence[Union[int, tir.Var]],\n        dtype: str,\n        max_batch_size: Union[int, tir.Var],\n        max_history: Union[int, tir.Var],\n        state_id: int,\n    ) -> tir.PrimFunc:\n        \"\"\"Create the set function with given state shape.\n\n        Parameters\n        ----------\n        shape : Sequence[Union[int, tir.Var]]\n            The shape of the state tensor.\n\n        dtype: str\n            The data type of the state tensor.\n\n        max_batch_size : Union[int, tir.Var]\n            The maximum batch size.\n\n        max_history : Union[int, tir.Var]\n            The maximum history length.\n\n        state_id : int\n            The id of the state, used for naming the function.\n\n        Returns\n        -------\n        tir.PrimFunc\n            The set function.\n        \"\"\"\n\n        def _func_one_dim():\n            @T.prim_func\n            def f(\n                var_storage: T.handle,\n                var_seq_slot_ids: T.handle,\n                var_history_slot_ids: T.handle,\n                var_data: T.handle,\n            ):\n                batch_size = T.int32(is_size_var=True)\n                T.func_attr({\"global_symbol\": f\"rnn_state_set_{state_id}\"})\n\n                storage = T.match_buffer(\n                    var_storage, (max_batch_size, max_history, shape[0]), dtype\n                )\n                seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), \"int32\")\n                history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), \"int32\")\n                data = T.match_buffer(var_data, (batch_size, shape[0]), dtype)\n\n                for i in range(batch_size):\n                    for 
s in range(shape[0]):\n                        with T.sblock(\"copy\"):\n                            vi, vs = T.axis.remap(\"SS\", [i, s])\n                            seq_id: T.int32 = seq_slot_ids[vi]\n                            history_id: T.int32 = (history_slot_ids[vi] + 1) % T.cast(\n                                max_history, \"int32\"\n                            )\n                            storage[seq_id, history_id, vs] = data[vi, vs]\n\n            return f\n\n        def _func_high_dim():\n            @T.prim_func\n            def f(\n                var_storage: T.handle,\n                var_seq_slot_ids: T.handle,\n                var_history_slot_ids: T.handle,\n                var_data: T.handle,\n            ):\n                batch_size = T.int32(is_size_var=True)\n                T.func_attr({\"global_symbol\": f\"rnn_state_set_{state_id}\"})\n\n                storage = T.match_buffer(var_storage, (max_batch_size, max_history, *shape), dtype)\n                seq_slot_ids = T.match_buffer(var_seq_slot_ids, (batch_size,), \"int32\")\n                history_slot_ids = T.match_buffer(var_history_slot_ids, (batch_size,), \"int32\")\n                data = T.match_buffer(var_data, (batch_size, *shape), dtype)\n\n                for i in range(batch_size):\n                    for s in T.grid(*shape):\n                        with T.sblock(\"copy\"):\n                            vi, *vs = T.axis.remap(\"S\" * (len(shape) + 1), [i, *s])\n                            seq_id: T.int32 = seq_slot_ids[vi]\n                            history_id: T.int32 = (history_slot_ids[vi] + 1) % T.cast(\n                                max_history, \"int32\"\n                            )\n                            # The following line is equivalent to:\n                            # `storage[seq_id, history_id, *vs] = data[vi, *vs]`\n                            # However, unpacking operator in subscript requires Python 3.11 or newer\n                          
  T.buffer_store(\n                                storage,\n                                T.BufferLoad(data, [vi, *vs]),\n                                [seq_id, history_id, *vs],\n                            )\n\n            return f\n\n        return _func_one_dim() if len(shape) == 1 else _func_high_dim()\n"
  },
  {
    "path": "python/mlc_llm/op/__init__.py",
    "content": "\"\"\"Extern module for compiler.\"\"\"\n\nfrom . import moe_matmul, moe_misc\nfrom .attention import attention\nfrom .batch_spec_verify import batch_spec_verify\nfrom .extern import configure, enable, get_store\nfrom .ft_gemm import faster_transformer_dequantize_gemm\nfrom .mrope import (\n    MultimodalRotaryEmbedding,\n    VisionPositionMetadata,\n    apply_multimodal_rotary_pos_emb,\n    get_mrope_position_ids,\n)\nfrom .pipeline_parallel import pipeline_stage_boundary\nfrom .top_p_pivot import top_p_pivot, top_p_renorm\n"
  },
  {
    "path": "python/mlc_llm/op/attention.py",
    "content": "\"\"\"Operators enabled by external modules.\"\"\"\n\nimport tvm\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\nfrom mlc_llm.support import logging\n\nfrom . import extern as _extern\n\nlogger = logging.getLogger(__name__)\n\n\nWARN_FLASHINFER_GROUP_SIZE = False\nWARN_FLASHINFER_HEAD_DIM = False\n\n\ndef attention(  # pylint: disable=invalid-name,too-many-locals,too-many-statements,too-many-arguments, unused-argument\n    q: nn.Tensor,\n    k: nn.Tensor,\n    v: nn.Tensor,\n    casual_mask: nn.Tensor,\n    attn_score_scaling_factor: float = 1.0,\n    qk_dtype: str = None,\n) -> nn.Tensor:\n    \"\"\"Attention with causal mask.\n\n    --- Variables ---\n    s: sequence length of the current query\n    t: total sequence length\n    d: head dimension\n    h, h_q: number of heads in query\n    h_kv: number of heads in key and value\n    b: batch size = 1\n\n    --- Shapes ---\n    q: [b, s, h_q, d]\n    k: [t, h_kv, d]\n    v: [t, h_kv, d]\n    o: [1, s, hidden = h_q * d]\n\n    --- Computation ---\n\n    .. 
code-block:: python\n\n        if h_kv != h_q:\n            k = k.repeat(h_q // h_kv, axis=1)\n            v = v.repeat(h_q // h_kv, axis=1)\n        q -> [b, h, s, d]\n        k, v -> [b, h, t, d]\n        attn = q @ k^T / sqrt(d) * attn_score_scaling_factor  # [b, h, s, t]\n        attn = softmax_with_mask(attn, casual_mask, axis=-1)\n        o = attn @ v  # [b, h, s, d]\n        o -> [b, s, h * d]\n\n    --- Other params ---\n    qk_dtype: if set, `matmul(Q, K, out_dtype=qk_dtype)`, (otherwise use `q.dtype` as `out_dtype`).\n        For FlashInfer, if \"float32\", sets `allow_fp16_qk_reduction` to False; otherwise no effect.\n    \"\"\"\n    assert q.ndim == 4 and k.ndim in [3, 4] and v.ndim in [3, 4]\n    b, s, h_q, d = q.shape\n    t, h_kv, _ = k.shape[-3:]\n    group_size = h_q // h_kv\n\n    def _fallback():\n        from tvm.relax.frontend.nn.llm.kv_cache import (  # pylint: disable=import-outside-toplevel\n            _attention_sequence_prefill,\n        )\n\n        nonlocal q, k, v, qk_dtype\n        if k.ndim == 3:\n            k = op.reshape(k, [b, t, h_kv, d])\n        if v.ndim == 3:\n            v = op.reshape(v, [b, t, h_kv, d])\n        if h_kv != h_q:\n            k = k.repeat(h_q // h_kv, axis=2)\n            v = v.repeat(h_q // h_kv, axis=2)\n\n        target = tvm.target.Target(\"cuda\")\n        attn_output, _ = op.tensor_ir_op(\n            _attention_sequence_prefill(  # pylint: disable=no-value-for-parameter\n                h_kv=h_kv,\n                h_q=h_q,\n                d=d,\n                dtype=q.dtype,\n                target=target,\n                sm_scale=attn_score_scaling_factor / (d**0.5),\n            ),\n            \"sequence_prefill\",\n            [q, k, v],\n            [\n                Tensor.placeholder([b, s, h_q, d], q.dtype),\n                Tensor.placeholder([b, s, h_q], q.dtype),\n            ],\n        )\n\n        output = op.reshape(attn_output, shape=(b, s, h_q * d))\n        return output\n\n    # 
FlashInfer Implementation\n    if (\n        _extern.get_store().flashinfer\n        and attn_score_scaling_factor == 1.0\n        and q.dtype == \"float16\"\n        and k.dtype == \"float16\"\n        and v.dtype == \"float16\"\n    ):\n        if group_size not in [1, 4, 6, 8]:\n            global WARN_FLASHINFER_GROUP_SIZE  # pylint: disable=global-statement\n            if not WARN_FLASHINFER_GROUP_SIZE:\n                WARN_FLASHINFER_GROUP_SIZE = True\n                logger.warning(\n                    \"FlashInfer only supports group size in [1, 4, 6, 8], but got %d. Skip and \"\n                    \"fallback to default implementation.\",\n                    group_size,\n                )\n            return _fallback()\n        if d not in [128]:\n            global WARN_FLASHINFER_HEAD_DIM  # pylint: disable=global-statement\n            if not WARN_FLASHINFER_HEAD_DIM:\n                WARN_FLASHINFER_HEAD_DIM = True\n                logger.warning(\n                    \"FlashInfer only supports head_dim in [128], but got %d. 
Skip and fallback to \"\n                    \"default implementation.\",\n                    d,\n                )\n            return _fallback()\n        rope_theta = 0.0\n        rope_scale = 1.0\n        qkv_layout = 0  # \"NHD\", N for seq_len, H for num_heads, D for head_dim\n        rotary_mode = 0  # \"kNone\"\n        casual = 1  # True\n        fp16_qk = 1  # True\n        if qk_dtype == \"float32\":\n            fp16_qk = 0  # False\n\n        # 32MB scratchpad\n        scratch = op.empty([8192 * 1024], dtype=\"float32\")  # pylint: disable=no-member\n\n        def _decode():\n            return op.extern(\n                name=\"flashinfer.single_decode\",\n                args=[\n                    q,\n                    k,\n                    v,\n                    scratch,\n                    qkv_layout,\n                    rotary_mode,\n                    rope_scale,\n                    rope_theta,\n                ],\n                out=nn.Tensor.placeholder((b, s, h_q * d), dtype=\"float16\"),\n            )\n\n        def _prefill():\n            return op.extern(\n                name=\"flashinfer.single_prefill\",\n                args=[\n                    q,\n                    k,\n                    v,\n                    scratch,\n                    casual,\n                    qkv_layout,\n                    rotary_mode,\n                    fp16_qk,\n                    rope_scale,\n                    rope_theta,\n                ],\n                out=nn.Tensor.placeholder((b, s, h_q * d), dtype=\"float16\"),\n            )\n\n        if isinstance(s, int) and s == 1:\n            func = \"decode\"\n        else:\n            func = \"prefill\"\n        return {\n            \"decode\": _decode,\n            \"prefill\": _prefill,\n        }[func]()\n\n    # Fallback Implementation\n    return _fallback()\n"
  },
  {
    "path": "python/mlc_llm/op/batch_matmul.py",
    "content": "\"\"\"Batch matmul operators\"\"\"\n\nfrom typing import Tuple\n\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.op import cutlass\nfrom mlc_llm.quantization.block_scale_quantization import rowwise_group_quant_fp8\n\n\ndef quantized_bmm(\n    x: nn.Tensor,\n    w: nn.Tensor,\n    w_scale: nn.Tensor,\n    block_size: Tuple[int, int],\n) -> nn.Tensor:\n    \"\"\"Quantized batch matmul.\n    Currently only the CUDA backend (via CUTLASS) is supported.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [b, m, k].\n\n    w : nn.Tensor\n        The weight tensor, with shape of [b, n, k] (column major).\n\n    w_scale : nn.Tensor\n        The scale tensor, with shape of [b, n // block_size[0], k // block_size[1]].\n\n    block_size : Tuple[int, int]\n        The block size along the n and k dimensions, respectively.\n\n    Returns\n    -------\n    ret : nn.Tensor\n        The output tensor, with shape of [b, m, n].\n    \"\"\"\n    x_fp8, x_scale = rowwise_group_quant_fp8(\n        x, block_size[1], w.dtype, transpose_scale=True, keep_first_batch_dim=True\n    )\n    return cutlass.fp8_groupwise_scaled_bmm(\n        x_fp8, x_scale, w, w_scale, block_size, out_dtype=x.dtype\n    )\n"
  },
  {
    "path": "python/mlc_llm/op/batch_spec_verify.py",
    "content": "\"\"\"Operators for batch verify in speculative decoding.\"\"\"\n\nfrom tvm.script import tir as T\n\n# mypy: disable-error-code=\"attr-defined,valid-type,name-defined\"\n# pylint: disable=too-many-locals,invalid-name,too-many-arguments,\n# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches\n\n\ndef batch_spec_verify(vocab_size):\n    \"\"\"Batch draft verify function. This function verifies the token tree.\n\n    Before calling the function\n\n    - token_tree_parent_ptr[b] should store the root of the tree\n\n    - draft_probs[node_id, :] stores the prob that samples the correspond tree node\n    - model_probs[node_id, :] stores the prob that should be used to sample its children\n    - Please note that the storage convention difference between model_probs and draft_probs\n        draft_probs was stored on the token node, while model_probs stores on the parent.\n        This is an intentional design since we can sample different child token with different\n        proposal draft probabilities, but the ground truth model_prob is unique per parent.\n\n    After calling the function\n    - token_tree_parent_ptr[b] points to the last token accepted\n    - There should be a followup sample step that samples from model_probs[token_tree_parent_ptr[b], :]\n        This token will be appended to the token generated.\n\n    This function will inplace update model_probs if a token was rejected and renormalization is needed.\n\n    Parameters\n    ----------\n    draft_probs:\n        The draft probability attached to each tree node\n\n    draft_tokens:\n        The draft token in each node\n\n    model_probs:\n        The model proability attached to each parent\n\n    token_tree_first_child:\n        The first child of each tree node, if there is no child, it should be -1\n\n    token_tree_next_sibling\n        The next sibling of each tree node, if there is no next sibling, it should be -1\n\n    uniform_samples\n    
    Per node uniform sample used to check rejection\n\n    token_tree_parent_ptr:\n        Current parent ptr state\n    \"\"\"\n    TX = 1024\n\n    def _var(dtype=\"int32\"):\n        return T.sblock_alloc_buffer((1,), dtype, scope=\"local\")\n\n    # fmt: off\n    @T.prim_func(private=True)\n    def _func(\n        var_draft_probs: T.handle,\n        var_draft_tokens: T.handle,\n        var_model_probs: T.handle,\n        var_token_tree_first_child: T.handle,\n        var_token_tree_next_sibling: T.handle,\n        var_uniform_samples: T.handle,\n        var_token_tree_parent_ptr: T.handle,\n    ):\n        \"\"\"\n        [\n            blockIdx.x on batch,\n            threadIdx.x on vocab_size,\n            for loop over excessive amounts\n        ]\n        \"\"\"\n        T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": True})\n        num_nodes = T.int32(is_size_var=True)\n        nbatch = T.int32(is_size_var=True)\n\n        draft_probs = T.match_buffer(var_draft_probs, (num_nodes, vocab_size), \"float32\")\n        draft_tokens = T.match_buffer(var_draft_tokens, (num_nodes,), \"int32\")\n        model_probs = T.match_buffer(var_model_probs, (num_nodes, vocab_size), \"float32\")\n        token_tree_first_child = T.match_buffer(var_token_tree_first_child, (num_nodes,), \"int32\")\n        token_tree_next_sibling = T.match_buffer(var_token_tree_next_sibling, (num_nodes,), \"int32\")\n        uniform_samples = T.match_buffer(var_uniform_samples, (num_nodes,), \"float32\")\n        token_tree_parent_ptr = T.match_buffer(var_token_tree_parent_ptr, (nbatch,), \"int32\")\n\n        with T.sblock(\"kernel\"):\n            child_ptr = _var()\n            parent_ptr = _var()\n            child_token = _var()\n            done = _var(\"bool\")\n            psum = _var(\"float32\")\n            t0 = _var(\"float32\")\n            model_prob_local = _var(\"float32\")\n            draft_prob_local = _var(\"float32\")\n            p_child = _var(\"float32\")\n     
       q_child = _var(\"float32\")\n            uniform_sample = _var(\"float32\")\n\n            pred_shared = T.sblock_alloc_buffer((1,), \"bool\", scope=\"shared\")\n            pred_local = T.sblock_alloc_buffer((1,), \"bool\", scope=\"local\")\n\n            for _bx in T.thread_binding(0, nbatch, thread=\"blockIdx.x\"):\n                for _tx in T.thread_binding(0, TX, thread=\"threadIdx.x\"):\n                    with T.sblock(\"CTA\"):\n                        # batch size\n                        b = T.axis.S(nbatch, _bx)\n                        tx = T.axis.S(TX, _tx)\n\n                        parent_ptr[0] = token_tree_parent_ptr[b]\n                        child_ptr[0] = token_tree_first_child[parent_ptr[0]]\n                        done[0] = False\n\n                        while T.Not(done[0]):\n                            T.tvm_storage_sync(\"shared\") # ensure all effects last round are visible\n                            if child_ptr[0] == -1:\n                                done[0] = True\n                                T.tvm_storage_sync(\"shared\") # sync before exit\n                            else:\n                                # decide to validate current ptr\n                                if tx == 0:\n                                    child_token[0] = draft_tokens[child_ptr[0]]\n                                    p_child[0] = model_probs[parent_ptr[0], child_token[0]]\n                                    q_child[0] = draft_probs[child_ptr[0], child_token[0]]\n                                    uniform_sample[0] = uniform_samples[child_ptr[0]]\n                                    pred_shared[0] = p_child[0] >= uniform_sample[0] * q_child[0]  # use multiplication to avoid division by zero\n                                T.tvm_storage_sync(\"shared\") # make sure all read of model_probs are done\n                                pred_local[0] = pred_shared[0]\n\n                                # accept the proposal, we move to 
child\n                                if pred_local[0]:\n                                    parent_ptr[0] = child_ptr[0]\n                                    child_ptr[0] = token_tree_first_child[child_ptr[0]]\n                                else:\n                                    psum[0] = 0.0\n                                    # renormalize probability, predicated by stopped_expansion[b]:\n                                    for i in T.serial(T.ceildiv(vocab_size, TX)):\n                                        k = T.meta_var(i * TX + tx)\n                                        if k < vocab_size:\n                                            model_prob_local[0] = model_probs[parent_ptr[0], k]\n                                            draft_prob_local[0] = draft_probs[child_ptr[0], k]\n                                            model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0)\n                                            psum[0] += model_prob_local[0]\n\n                                    with T.sblock(\"block_cross_thread\"):\n                                        T.reads(psum[0])\n                                        T.writes(t0[0])\n                                        T.attr(\n                                            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),\n                                            \"reduce_scope\",\n                                            T.reinterpret(\"handle\", T.uint64(0)),\n                                        )\n                                        T.tvm_thread_allreduce(T.uint32(1), psum[0], True, t0[0], tx, dtype=\"handle\")\n\n                                    if t0[0] < 1e-7:\n                                        # accept the proposal, we move to child\n                                        parent_ptr[0] = child_ptr[0]\n                                        child_ptr[0] = token_tree_first_child[child_ptr[0]]\n                                    
else:\n                                        # renormalize\n                                        for i in T.serial(T.ceildiv(vocab_size, TX)):\n                                            k = T.meta_var(i * TX + tx)\n                                            if k < vocab_size:\n                                                model_prob_local[0] = model_probs[parent_ptr[0], k]\n                                                draft_prob_local[0] = draft_probs[child_ptr[0], k]\n                                                model_prob_local[0] = T.max(model_prob_local[0] - draft_prob_local[0], 0.0)\n                                                model_probs[parent_ptr[0], k] = model_prob_local[0] / t0[0]\n\n                                        child_ptr[0] = token_tree_next_sibling[child_ptr[0]]\n\n                        if tx == 0:\n                            token_tree_parent_ptr[b] = parent_ptr[0]\n    # fmt: on\n\n    return _func\n"
  },
  {
    "path": "python/mlc_llm/op/cutlass.py",
    "content": "\"\"\"Operators enabled by external modules.\"\"\"\n\nfrom typing import Optional, Tuple\n\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import op\n\n\ndef group_gemm(\n    x: nn.Tensor,\n    weight: nn.Tensor,\n    indptr: nn.Tensor,\n    scale: Optional[nn.Tensor] = None,\n    weight_dtype: Optional[str] = None,\n    out_dtype: Optional[str] = None,\n):  # pylint: disable=too-many-arguments\n    \"\"\"\n    Cutlass group gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [m, k].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [num_groups, n, k].\n\n    indptr : nn.Tensor\n        The indptr tensor, with shape of [num_groups].\n\n    scale : Optional[nn.Tensor]\n        The scale tensor, with shape of [1].\n\n    weight_dtype: Optional[str]\n        The data type of the weight tensor.\n\n    out_dtype: Optional[str]\n        The data type of the output tensor.\n\n    Returns\n    -------\n    nn.Tensor\n        The output tensor, with shape of [m, n].\n    \"\"\"\n    assert x.ndim == 2\n    assert weight.ndim == 3\n    assert indptr.ndim == 1\n    assert weight.shape[0] == indptr.shape[0]\n    assert indptr.dtype == \"int64\"\n    out_dtype = out_dtype if out_dtype else x.dtype\n    weight_dtype = weight_dtype if weight_dtype else weight.dtype\n\n    # pylint: disable=too-many-boolean-expressions\n    if x.dtype == \"float8_e5m2\" and weight_dtype == \"float8_e5m2\" and out_dtype == \"float16\":\n        func_name = \"cutlass.group_gemm_e5m2_e5m2_fp16\"\n    elif x.dtype == \"float8_e4m3fn\" and weight_dtype == \"float8_e5m2\" and out_dtype == \"float16\":\n        func_name = \"cutlass.group_gemm_e4m3_e5m2_fp16\"\n    elif x.dtype == \"float8_e4m3fn\" and weight_dtype == \"float8_e4m3fn\" and out_dtype == \"float16\":\n        func_name = \"cutlass.group_gemm_e4m3_e4m3_fp16\"\n    elif (x.dtype == \"float16\" and weight_dtype == \"float16\" and 
out_dtype == \"float16\") or (\n        x.dtype == \"bfloat16\" and weight_dtype == \"bfloat16\" and out_dtype == \"bfloat16\"\n    ):\n        func_name = \"cutlass.group_gemm\"\n    else:\n        raise NotImplementedError(\n            f\"Unsupported data type: x={x.dtype}, weight={weight_dtype}, out={out_dtype}\"\n        )\n    # pylint: enable=too-many-boolean-expressions\n\n    if \"float8\" in x.dtype:\n        assert scale is not None, \"scale is required for float8 input\"\n\n    workspace = op.empty((4096 * 1024,), dtype=\"uint8\", name=\"workspace\")\n\n    return op.extern(\n        func_name,\n        args=[x, weight, indptr, workspace] + ([scale] if scale is not None else []),\n        out=nn.Tensor.placeholder((x.shape[0], weight.shape[1]), dtype=out_dtype),\n    )\n\n\ndef fp8_gemm(\n    x: nn.Tensor,\n    weight: nn.Tensor,\n    scale: nn.Tensor,\n    weight_dtype: Optional[str] = None,\n    out_dtype: Optional[str] = None,\n):  # pylint: disable=too-many-arguments\n    \"\"\"\n    Cutlass fp8 gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [m, k].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [n, k].\n\n    scale : nn.Tensor\n        The scale tensor, with shape of [1].\n\n    weight_dtype: Optional[str]\n        The data type of the weight tensor.\n\n    out_dtype: Optional[str]\n        The data type of the output tensor.\n\n    Returns\n    -------\n    nn.Tensor\n        The output tensor, with shape of [m, n].\n    \"\"\"\n    assert x.ndim >= 2\n    assert weight.ndim == 2\n    assert scale.ndim == 1 and scale.shape[0] == 1\n    out_dtype = out_dtype if out_dtype else x.dtype\n    weight_dtype = weight_dtype if weight_dtype else weight.dtype\n\n    if x.dtype == \"float8_e5m2\" and weight_dtype == \"float8_e5m2\" and out_dtype == \"float16\":\n        func_name = \"cutlass.gemm_e5m2_e5m2_fp16\"\n    elif x.dtype == \"float8_e4m3fn\" and weight_dtype == \"float8_e5m2\" and out_dtype == \"float16\":\n        func_name = \"cutlass.gemm_e5m2_e4m3_fp16\"\n    elif x.dtype == \"float8_e4m3fn\" and weight_dtype == \"float8_e4m3fn\" and out_dtype == \"float16\":\n        func_name = \"cutlass.gemm_e4m3_e4m3_fp16\"\n    else:\n        raise NotImplementedError(\n            f\"Unsupported data type: x={x.dtype}, weight={weight_dtype}, out={out_dtype}\"\n        )\n\n    workspace = op.empty((4096 * 1024,), dtype=\"uint8\", name=\"workspace\")\n\n    return op.extern(\n        func_name,\n        args=[x, weight, workspace, scale],\n        out=nn.Tensor.placeholder((*x.shape[:-1], weight.shape[0]), dtype=out_dtype),\n    )\n\n\ndef fp8_groupwise_scaled_gemm(  # pylint: disable=too-many-arguments\n    x: nn.Tensor,\n    x_scale: nn.Tensor,\n    weight: nn.Tensor,\n    weight_scale: nn.Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n):\n    \"\"\"Cutlass block-scale fp8 gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [m, k].\n\n    x_scale : nn.Tensor\n        The scale tensor, with shape of [k // block_size, m].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [n, k].\n\n    weight_scale : nn.Tensor\n        The scale tensor, with shape of [n // block_size, k // block_size].\n\n    block_size : Tuple[int, int]\n        The block size.\n\n    out_dtype : str\n        The data type of the output tensor.\n\n    Returns\n    -------\n    out : nn.Tensor\n        The output tensor, with shape of [m, n] and dtype of `out_dtype`.\n    \"\"\"\n    assert x.ndim >= 2\n    assert weight.ndim == 2\n    assert x_scale.ndim == x.ndim\n    assert weight_scale.ndim == weight.ndim\n\n    if block_size[0] != 128 or block_size[1] != 128:\n        raise ValueError(f\"block_size must be (128, 128), but got {block_size}\")\n    if x.dtype != \"float8_e4m3fn\" or weight.dtype != \"float8_e4m3fn\":\n        raise ValueError(\n          
  f\"x and weight must be float8_e4m3fn, but got x={x.dtype}, weight={weight.dtype}\"\n        )\n    if x_scale.dtype != \"float32\" or weight_scale.dtype != \"float32\":\n        raise ValueError(\n            \"x_scale and weight_scale must be float32, but got \"\n            f\"x_scale={x_scale.dtype}, weight_scale={weight_scale.dtype}\"\n        )\n    if out_dtype not in [\"float16\", \"bfloat16\"]:\n        raise ValueError(f\"out_dtype must be float16 or bfloat16, but got {out_dtype}\")\n\n    func_name = \"cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn\"\n    workspace = op.empty((4096 * 1024,), dtype=\"uint8\", name=\"workspace\")\n    return op.extern(\n        func_name,\n        args=[\n            x,\n            weight,\n            x_scale,\n            weight_scale,\n            workspace,\n            block_size[0],\n            block_size[1],\n        ],\n        out=nn.Tensor.placeholder((*x.shape[:-1], weight.shape[0]), dtype=out_dtype),\n    )\n\n\ndef fp8_groupwise_scaled_bmm(  # pylint: disable=too-many-arguments\n    x: nn.Tensor,\n    x_scale: nn.Tensor,\n    weight: nn.Tensor,\n    weight_scale: nn.Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n):\n    \"\"\"Cutlass block-scale fp8 gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [b, m, k].\n\n    x_scale : nn.Tensor\n        The scale tensor, with shape of [b, k // block_size, m].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [b, n, k].\n\n    weight_scale : nn.Tensor\n        The scale tensor, with shape of [b, n // block_size, k // block_size].\n\n    block_size : Tuple[int, int]\n        The block size.\n\n    out_dtype : str\n        The data type of the output tensor.\n\n    Returns\n    -------\n    out : nn.Tensor\n        The output tensor, with shape of [m, n] and dtype of `out_dtype`.\n    \"\"\"\n    assert x.ndim == 3\n    assert weight.ndim == 3\n    assert x_scale.ndim == 
x.ndim\n    assert weight_scale.ndim == weight.ndim\n    assert x.shape[0] == x_scale.shape[0] == weight.shape[0] == weight_scale.shape[0]\n\n    if block_size[0] != 128 or block_size[1] != 128:\n        raise ValueError(f\"block_size must be (128, 128), but got {block_size}\")\n    if x.dtype != \"float8_e4m3fn\" or weight.dtype != \"float8_e4m3fn\":\n        raise ValueError(\n            f\"x and weight must be float8_e4m3fn, but got x={x.dtype}, weight={weight.dtype}\"\n        )\n    if x_scale.dtype != \"float32\" or weight_scale.dtype != \"float32\":\n        raise ValueError(\n            \"x_scale and weight_scale must be float32, but got \"\n            f\"x_scale={x_scale.dtype}, weight_scale={weight_scale.dtype}\"\n        )\n    if out_dtype not in [\"float16\", \"bfloat16\"]:\n        raise ValueError(f\"out_dtype must be float16 or bfloat16, but got {out_dtype}\")\n\n    func_name = \"cutlass.groupwise_scaled_bmm_e4m3fn_e4m3fn\"\n    workspace = op.empty((4096 * 1024,), dtype=\"uint8\", name=\"workspace\")\n    return op.extern(\n        func_name,\n        args=[\n            x,\n            weight,\n            x_scale,\n            weight_scale,\n            workspace,\n            block_size[0],\n            block_size[1],\n        ],\n        out=nn.Tensor.placeholder((x.shape[0], x.shape[1], weight.shape[1]), dtype=out_dtype),\n    )\n\n\ndef fp8_groupwise_scaled_group_gemm(  # pylint: disable=too-many-arguments,too-many-locals\n    x: nn.Tensor,\n    x_scale: nn.Tensor,\n    weight: nn.Tensor,\n    weight_scale: nn.Tensor,\n    indptr: nn.Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n):\n    \"\"\"Cutlass block-scale fp8 group gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [m, k].\n\n    x_scale : nn.Tensor\n        The scale tensor, with shape of [m, k // block_size].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [num_experts, n, k].\n\n    
weight_scale : nn.Tensor\n        The scale tensor, with shape of [num_experts, n // block_size, k // block_size].\n\n    indptr : nn.Tensor\n        The indptr tensor of group gemm, with shape of [num_experts].\n\n    block_size : Tuple[int, int]\n        The block size.\n\n    out_dtype : str\n        The data type of the output tensor.\n\n    Returns\n    -------\n    out : nn.Tensor\n        The output tensor, with shape of [m, n] and dtype of `out_dtype`.\n    \"\"\"\n    assert x.ndim >= 2\n    assert weight.ndim == 3\n    assert x_scale.ndim == x.ndim\n    assert weight_scale.ndim == weight.ndim\n    assert x.shape[-1] == weight.shape[2]\n    assert (x.shape[-1] + block_size[1] - 1) // block_size[1] == x_scale.shape[-1]\n    assert (weight.shape[2] + block_size[1] - 1) // block_size[1] == weight_scale.shape[2]\n    assert (weight.shape[1] + block_size[0] - 1) // block_size[0] == weight_scale.shape[1]\n\n    if block_size[0] != 128 or block_size[1] != 128:\n        raise ValueError(f\"block_size must be (128, 128), but got {block_size}\")\n    if x.dtype != \"float8_e4m3fn\" or weight.dtype != \"float8_e4m3fn\":\n        raise ValueError(\n            f\"x and weight must be float8_e4m3fn, but got x={x.dtype}, weight={weight.dtype}\"\n        )\n    if x_scale.dtype != \"float32\" or weight_scale.dtype != \"float32\":\n        raise ValueError(\n            \"x_scale and weight_scale must be float32, but got \"\n            f\"x_scale={x_scale.dtype}, weight_scale={weight_scale.dtype}\"\n        )\n    if out_dtype not in [\"float16\", \"bfloat16\"]:\n        raise ValueError(f\"out_dtype must be float16 or bfloat16, but got {out_dtype}\")\n\n    num_experts = weight.shape[0]\n    m = x.shape[0]\n    for i in range(1, x.ndim - 1):\n        m *= x.shape[i]\n    n = weight.shape[1]\n    k = x.shape[-1]\n    assert weight_scale.shape[0] == num_experts\n    assert indptr.ndim == 1\n    assert indptr.shape[0] == num_experts\n    assert indptr.dtype == 
\"int64\"\n\n    x_shape = x.shape\n    if x.ndim > 2:\n        x = x.reshape(m, k)\n        x_scale = x_scale.reshape(m, x_scale.shape[-1])\n\n    func_name = \"cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn\"\n    workspace = op.empty((4096 * 1024,), dtype=\"uint8\", name=\"workspace\")\n    out = op.extern(\n        func_name,\n        args=[\n            x,\n            weight,\n            x_scale,\n            weight_scale,\n            indptr,\n            workspace,\n            block_size[0],\n            block_size[1],\n        ],\n        out=nn.Tensor.placeholder((m, n), dtype=out_dtype),\n    )\n    return out.reshape(*x_shape[:-1], n) if len(x_shape) > 2 else out\n"
  },
  {
    "path": "python/mlc_llm/op/extern.py",
    "content": "\"\"\"Potential externel modules managed by MLC compilation stack.\n\nAn externl module could contain one or multiple handcrafted kernels, as long as it is provided as\nan object file (`.o`), a C++ source file (`.cc`), or a CUDA source file (`.cu`). It can be\nintegrated into the system pretty smoothly.\n\nAs examples, `flashinfer.py` contains such an example that instructs MLC to compile\n\"$tvm_home/3rdparty/flashinfer/src/tvm_wrapper.cu\" with a specific set of compilation flags and then\nlink into the generated artifact of MLC LLM. TVM PR #16247\n(https://github.com/apache/tvm/pull/16247/) provides more details of using TVM's\n`nn.SourceModule` to integrate C++ and CUDA files, and `nn.ObjectModule` to integrate object files.\n\nTo conveniently use those externel modules, MLC LLM compilation pipeline manages an extra global\nsingleton `Store: ExternalModuleStore` to store the configured modules. It is supposed to be enabled\nbefore any compilation happens, and configured during a model's `forward` method is invoked.\n\"\"\"\n\nimport dataclasses\nfrom typing import Optional\n\nfrom tvm.target import Target\n\n\n@dataclasses.dataclass\nclass ExternModuleStore:\n    \"\"\"Global store of external modules enabled during compilation.\"\"\"\n\n    configured: bool = False\n    target: Optional[Target] = None\n    flashinfer: bool = False\n    faster_transformer: bool = False\n    cutlass_group_gemm: bool = False\n    cutlass_gemm: bool = False\n\n\nSTORE: ExternModuleStore = ExternModuleStore()\n\"\"\"Singleton of `ExternModuleStore`.\"\"\"\n\n\ndef enable(target: Target, flashinfer: bool, faster_transformer: bool, cutlass: bool) -> None:\n    \"\"\"Enable external modules. 
It should be called before any compilation happens.\"\"\"\n    global STORE  # pylint: disable=global-statement\n    cutlass = (\n        cutlass\n        and target.kind.name == \"cuda\"\n        and target.attrs.get(\"arch\", \"\") in [\"sm_90a\", \"sm_100a\"]\n    )\n    faster_transformer = False\n    STORE = ExternModuleStore(\n        configured=False,\n        target=target,\n        flashinfer=flashinfer,\n        faster_transformer=faster_transformer,\n        cutlass_group_gemm=cutlass,\n        cutlass_gemm=cutlass,\n    )\n\n\ndef get_store() -> ExternModuleStore:\n    \"\"\"Get the global store of external modules.\"\"\"\n    return STORE\n\n\ndef configure() -> None:\n    \"\"\"Configure external modules with extra parameters. It should be called during a model's\n    `forward` method is invoked.\n\n    Parameters\n    ----------\n    \"\"\"\n    store = get_store()\n    if store.configured:\n        return\n    store.configured = True\n    if store.flashinfer or store.faster_transformer:\n        assert store.target.kind.name == \"cuda\"\n"
  },
  {
    "path": "python/mlc_llm/op/ft_gemm.py",
    "content": "\"\"\"Operators enabled by external modules.\"\"\"\n\nimport operator\nfrom functools import reduce\nfrom typing import Optional\n\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import op\n\n\ndef faster_transformer_dequantize_gemm(  # pylint: disable=too-many-arguments\n    x: nn.Tensor,\n    weight: nn.Tensor,\n    scale: nn.Tensor,\n    bias: Optional[nn.Tensor] = None,\n    activation: Optional[str] = None,\n    group_size: Optional[int] = None,\n):\n    \"\"\"\n    Faster Transformer dequantize gemm inference with CutlassFpAIntB\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [*m, k].\n\n    weight : nn.Tensor\n        The quantized weight data tensor, with shape of [k, n // num_elem_per_storage].\n\n    scale : nn.Tensor\n        The quantized weight scale tensor, with shape of [k // group_size, n].\n\n    bias : Optional[nn.Tensor]\n        The optional bias for matmul, with shape broadcastable to [*m, n].\n\n    group_size : Optional[int]\n        The optional group size. 
If not set, then using k as group size.\n\n    Returns\n    ------\n    ret: nn.Tensor\n        The output tensor of deocde matmul, with shape of [*m, n].\n    \"\"\"\n    assert x.dtype == \"float16\" and x.ndim >= 1\n    assert weight.ndim == 2\n    assert scale.dtype == \"float16\" and scale.ndim == 2\n    assert x.shape[-1] == weight.shape[0], (\n        \"Reduction dimension mismatched between x and weight, \"\n        f\"{x.shape[-1]} vs {weight.shape[0]}.\"\n    )\n    assert activation in [\n        None,\n        \"relu\",\n        \"gelu\",\n        \"silu\",\n        \"identity\",\n    ], \"Supported activations are [None, 'identity', 'gelu', 'silu', 'relu'].\"\n    activation = activation if activation else \"identity\"\n    m = reduce(operator.mul, x.shape[:-1], 1)\n    k = x.shape[-1]\n    n = scale.shape[1]\n\n    if not group_size:\n        group_size = k\n\n    if bias:\n        assert bias.dtype == \"float16\" and bias.ndim >= 1\n        bias_stride = (\n            bias.shape[-1]\n            if bias and not reduce(operator.mul, bias.shape, 1) == bias.shape[-1]\n            else 0\n        )\n        return op.extern(\n            name=\"fastertransformer.gemm_fp16_int_bias\",\n            args=[\n                x,\n                weight,\n                scale,\n                bias,\n                activation,\n                m,\n                n,\n                k,\n                group_size,\n                bias_stride,\n            ],\n            out=nn.Tensor.placeholder((*x.shape[:-1], scale.shape[1]), dtype=\"float16\"),\n        )\n    return op.extern(\n        name=\"fastertransformer.gemm_fp16_int\",\n        args=[x, weight, scale, activation, m, n, k, group_size],\n        out=nn.Tensor.placeholder((*x.shape[:-1], scale.shape[1]), dtype=\"float16\"),\n    )\n\n\ndef faster_transformer_moe_gemm(  # pylint: disable=too-many-arguments\n    x: nn.Tensor,\n    weight: nn.Tensor,\n    total_rows_before: nn.Tensor,\n):\n    
\"\"\"\n    Faster Transformer moe gemm inference with CutlassFpAIntB\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [*m, k].\n\n    weight : nn.Tensor\n        The weight data tensor, with shape of [num_experts, n, k].\n\n    total_rows_before : nn.Tensor\n        The total rows before tensor the current expert, with shape of [num_experts]. This is the\n        same as the indptr excluding the first zero element.\n\n    Returns\n    ------\n    ret: nn.Tensor\n        The output tensor of deocde matmul, with shape of [*m, n].\n    \"\"\"\n    assert x.dtype == \"float16\" and x.ndim >= 1\n    assert weight.dtype == \"float16\" and weight.ndim == 3\n    assert x.shape[-1] == weight.shape[-1], (\n        \"Reduction dimension mismatched between x and weight, \"\n        f\"{x.shape[-1]} vs {weight.shape[-1]}.\"\n    )\n    m = reduce(operator.mul, x.shape[:-1], 1)\n    num_experts = weight.shape[0]\n    n = weight.shape[1]\n    k = x.shape[-1]\n\n    return op.extern(\n        name=\"fastertransformer.moe_gemm_fp16_fp16\",\n        args=[x, weight, total_rows_before, m, n, k, num_experts],\n        out=nn.Tensor.placeholder((*x.shape[:-1], n), dtype=\"float16\"),\n    )\n"
  },
  {
    "path": "python/mlc_llm/op/moe_matmul.py",
    "content": "\"\"\"Mixture of Experts operators\"\"\"\n\nfrom typing import Literal, Optional, Tuple\n\nfrom tvm import DataType, DataTypeCode, s_tir, tir\nfrom tvm.relax.frontend.nn import Tensor, op\nfrom tvm.script import tir as T\n\n# mypy: disable-error-code=\"attr-defined,valid-type,name-defined\"\n# pylint: disable=too-many-locals,invalid-name,too-many-arguments,too-many-statements\n\n\ndef gemv(x: Tensor, w: Tensor, indptr: Tensor) -> Tensor:\n    \"\"\"GEMV for project-in (e1-e3) or project-out (e2) in MLP.\n\n    Parameters\n    ----------\n    x : Tensor\n        For project-in, the input tensor of shape (1, in_features); and for project-out, the input\n        shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated\n        experts per token.\n\n    w : Tensor\n        The weight tensor of shape (local_experts, out_features, in_features), where `local_experts`\n        is the total number of experts.\n\n    indptr : Tensor\n        The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the\n        number of activated experts per token.\n\n    Returns\n    -------\n    out : Tensor\n        The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the\n        number of activated experts per token.\n    \"\"\"\n    (local_experts, out_features, in_features), dtype = w.shape, w.dtype\n    _, experts_per_tok = indptr.shape\n    x_leading_dim, _ = x.shape\n\n    def access_x(x, e, j):\n        return x[0, j] if x_leading_dim == 1 else x[e, j]\n\n    # NOTE: Currently it assumes x.dtype == w.dtype, but the constraint can be relaxed easily.\n    assert w.shape == [local_experts, out_features, in_features] and w.dtype == dtype\n    assert x.shape == [x_leading_dim, in_features] and x.dtype == dtype\n    assert indptr.shape == [1, experts_per_tok] and indptr.dtype == \"int32\"\n    assert x_leading_dim in [1, experts_per_tok]\n\n    @T.prim_func(private=True)\n    
def _func(\n        x: T.Buffer((x_leading_dim, in_features), dtype),\n        w: T.Buffer((local_experts, out_features, in_features), dtype),\n        indptr: T.Buffer((1, experts_per_tok), \"int32\"),\n        o: T.Buffer((experts_per_tok, out_features), dtype),\n    ):\n        T.func_attr({\"op_pattern\": 4, \"tir.noalias\": True})  # kOutEWiseFusable\n        for e in T.thread_binding(experts_per_tok, thread=\"blockIdx.y\"):\n            with T.sblock(\"gemv_o\"):\n                e = T.axis.spatial(experts_per_tok, e)\n                T.reads(x[:, :], w[indptr[0, e], :, :], indptr[0, e])\n                T.writes(o[e, :])\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"gemv\"):\n                        i, j = T.axis.remap(\"SR\", [i1, i2])\n                        with T.init():\n                            o[e, i] = T.cast(T.float16(0), dtype)\n                        o[e, i] += access_x(x, e, j) * w[indptr[0, e], i, j]\n\n    return op.tensor_ir_op(\n        _func,\n        \"moe_gemv\",\n        args=[x, w, indptr],\n        out=Tensor.placeholder([experts_per_tok, out_features], dtype),\n    )\n\n\ndef dequantize_gemv(  # pylint: disable=too-many-arguments\n    x: Tensor,\n    w: Tensor,\n    scale: Tensor,\n    indptr: Tensor,\n    quantize_dtype: str,\n    group_size: int,\n) -> Tensor:\n    \"\"\"GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized.\n    It needs to be dequantized before the GEMV computation.\n\n    Parameters\n    ----------\n    x : Tensor\n        For project-in, the input tensor of shape (1, in_features); and for project-out, the input\n        shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated\n        experts per token.\n\n    w : Tensor\n        The quantized weight tensor of shape (local_experts, out_features, in_features // n),\n        where n is the number of elements per storage dtype, e.g. 
if the storage dtype is uint32,\n        and the quantize dtype is int4, then n is 8.\n        `local_experts` is the total number of experts including activated and non-active ones.\n\n    scale : Tensor\n        The scale tensor of shape (local_experts, out_features, in_features // group_size), where\n        `local_experts` is the total number of experts including activated and non-active ones.\n\n    indptr : Tensor\n        The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the\n        number of activated experts per token.\n\n    quantize_dtype : str\n        The quantize dtype of the weight tensor, which is usually int3, int4 or fp8, etc.\n\n    group_size : int\n        The number of elements in each quantization group, e.g. 32 or 128.\n\n    Returns\n    -------\n    out : Tensor\n        The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the\n        number of activated experts per token.\n    \"\"\"\n    (x_leading_dim, in_features), model_dtype = x.shape, x.dtype\n    (local_experts, out_features, _), storage_dtype = w.shape, w.dtype\n    _, experts_per_tok = indptr.shape\n    quantize_dtype_bits = DataType(quantize_dtype).bits\n    num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits\n    num_group = (in_features + group_size - 1) // group_size\n    num_storage = group_size // num_elem_per_storage * num_group\n\n    def _dequantize(w, s, e, i, j):\n        tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype)\n        tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype)\n        w = w[e, i, j // num_elem_per_storage]\n        s = s[e, i, j // group_size]\n        shift = (j % num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype)\n        w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype)\n        return (w - tir_max_int) * s\n\n    def access_x(x, e, j):\n        return x[0, j] if 
x_leading_dim == 1 else x[e, j]\n\n    assert x.shape == [x_leading_dim, in_features] and x.dtype == model_dtype\n    assert w.shape == [local_experts, out_features, num_storage] and w.dtype == storage_dtype\n    assert scale.shape == [local_experts, out_features, num_group] and scale.dtype == model_dtype\n    assert indptr.shape == [1, experts_per_tok] and indptr.dtype == \"int32\"\n    assert x_leading_dim in [1, experts_per_tok]\n\n    @T.prim_func(private=True)\n    def _func(\n        x: T.Buffer((x_leading_dim, in_features), model_dtype),\n        w: T.Buffer((local_experts, out_features, num_storage), storage_dtype),\n        scale: T.Buffer((local_experts, out_features, num_group), model_dtype),\n        indptr: T.Buffer((1, experts_per_tok), \"int32\"),\n        o: T.Buffer((experts_per_tok, out_features), model_dtype),\n    ):\n        T.func_attr({\"op_pattern\": 4, \"tir.noalias\": True})  # kOutEWiseFusable\n        for expert_id in T.thread_binding(experts_per_tok, thread=\"blockIdx.y\"):\n            with T.sblock(\"gemv_o\"):\n                e = T.axis.spatial(experts_per_tok, expert_id)\n                y = T.sblock_alloc_buffer((out_features, in_features), model_dtype)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"dequantize\"):\n                        i, j = T.axis.remap(\"SS\", [i1, i2])\n                        y[i, j] = _dequantize(w, scale, indptr[0, e], i, j)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"gemv\"):\n                        i, j = T.axis.remap(\"SR\", [i1, i2])\n                        with T.init():\n                            o[e, i] = T.cast(T.float16(0), model_dtype)\n                        o[e, i] += access_x(x, e, j) * y[i, j]\n\n    return op.tensor_ir_op(\n        _func,\n        \"moe_dequantize_gemv\",\n        args=[x, w, scale, indptr],\n        out=Tensor.placeholder([experts_per_tok, out_features], 
model_dtype),\n    )\n\n\ndef dequantize_float8_gemv(\n    x: Tensor,\n    w: Tensor,\n    scale: Optional[Tensor],\n    indptr: Tensor,\n    quantize_dtype: Literal[\"float8_e5m2\", \"float8_e4m3fn\"],\n) -> Tensor:\n    \"\"\"GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized in\n    fp8 e5m2 or e4m3. It needs to be dequantized before the GEMV computation.\n\n    Parameters\n    ----------\n    x : Tensor\n        For project-in, the input tensor of shape (1, in_features); and for project-out, the input\n        shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated\n        experts per token.\n\n    w : Tensor\n        The quantized weight tensor of shape (local_experts, out_features, ceil(in_features / n)),\n        where n is the number of fp8 elements packed into one storage element (n is 1 when each\n        storage element holds a single fp8 value).\n\n    scale : Optional[Tensor]\n        The optional float32 scale tensor of shape (1,), applied to every dequantized weight.\n\n    indptr : Tensor\n        The index pointer tensor of shape (1, experts_per_tok), where `experts_per_tok` is the\n        number of activated experts per token.\n\n    quantize_dtype : Literal[\"float8_e5m2\", \"float8_e4m3fn\"]\n        The quantize dtype of the weight tensor, which is either float8_e5m2 or float8_e4m3fn.\n\n    Returns\n    -------\n    out : Tensor\n        The output tensor of shape (experts_per_tok, out_features), where `experts_per_tok` is the\n        number of activated experts per token.\n    \"\"\"\n    (x_leading_dim, in_features), model_dtype = x.shape, x.dtype\n    (local_experts, out_features, _), storage_dtype = w.shape, w.dtype\n    _, experts_per_tok = indptr.shape\n    quantize_dtype_bits = DataType(quantize_dtype).bits\n    num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits\n    num_storage = tir.ceildiv(in_features, num_elem_per_storage)\n\n    def _dequantize(w, s, e, i, j):\n        if num_elem_per_storage == 1:\n            w = tir.reinterpret(quantize_dtype, w[e, i, j])\n        else:\n            assert DataType(storage_dtype).type_code == DataTypeCode.UINT\n            tir_bin_mask = tir.const((2**quantize_dtype_bits) - 1, storage_dtype)\n            w = w[e, i, j // num_elem_per_storage]\n            shift = (j % num_elem_per_storage * 
quantize_dtype_bits).astype(storage_dtype)\n            w = tir.reinterpret(\n                quantize_dtype,\n                tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(\"uint8\"),\n            )\n        w = w.astype(model_dtype)\n        if s is not None:\n            w = w * s[0]\n        return w\n\n    def access_x(x, e, j):\n        return x[0, j] if x_leading_dim == 1 else x[e, j]\n\n    @T.prim_func(private=True)\n    def _func_with_scale(\n        x: T.Buffer((x_leading_dim, in_features), model_dtype),\n        w: T.Buffer((local_experts, out_features, num_storage), storage_dtype),\n        scale: T.Buffer((1,), \"float32\"),\n        indptr: T.Buffer((1, experts_per_tok), \"int32\"),\n        o: T.Buffer((experts_per_tok, out_features), model_dtype),\n    ):\n        T.func_attr({\"op_pattern\": 4, \"tir.noalias\": True})  # kOutEWiseFusable\n        for expert_id in T.thread_binding(experts_per_tok, thread=\"blockIdx.y\"):\n            with T.sblock(\"gemv_o\"):\n                e = T.axis.spatial(experts_per_tok, expert_id)\n                y = T.sblock_alloc_buffer((out_features, in_features), model_dtype)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"dequantize\"):\n                        i, j = T.axis.remap(\"SS\", [i1, i2])\n                        y[i, j] = _dequantize(w, scale, indptr[0, e], i, j)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"gemv\"):\n                        i, j = T.axis.remap(\"SR\", [i1, i2])\n                        with T.init():\n                            o[e, i] = T.cast(T.float16(0), model_dtype)\n                        o[e, i] += access_x(x, e, j) * y[i, j]\n\n    @T.prim_func(private=True)\n    def _func_without_scale(\n        x: T.Buffer((x_leading_dim, in_features), model_dtype),\n        w: T.Buffer((local_experts, out_features, num_storage), storage_dtype),\n        indptr: 
T.Buffer((1, experts_per_tok), \"int32\"),\n        o: T.Buffer((experts_per_tok, out_features), model_dtype),\n    ):\n        T.func_attr({\"op_pattern\": 4, \"tir.noalias\": True})  # kOutEWiseFusable\n        for expert_id in T.thread_binding(experts_per_tok, thread=\"blockIdx.y\"):\n            with T.sblock(\"gemv_o\"):\n                e = T.axis.spatial(experts_per_tok, expert_id)\n                y = T.sblock_alloc_buffer((out_features, in_features), model_dtype)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"dequantize\"):\n                        i, j = T.axis.remap(\"SS\", [i1, i2])\n                        y[i, j] = _dequantize(w, None, indptr[0, e], i, j)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"gemv\"):\n                        i, j = T.axis.remap(\"SR\", [i1, i2])\n                        with T.init():\n                            o[e, i] = T.cast(T.float16(0), model_dtype)\n                        o[e, i] += access_x(x, e, j) * y[i, j]\n\n    if scale is not None:\n        return op.tensor_ir_op(\n            _func_with_scale,\n            \"moe_dequantize_gemv\",\n            args=[x, w, scale, indptr],\n            out=Tensor.placeholder([experts_per_tok, out_features], model_dtype),\n        )\n    return op.tensor_ir_op(\n        _func_without_scale,\n        \"moe_dequantize_gemv\",\n        args=[x, w, indptr],\n        out=Tensor.placeholder([experts_per_tok, out_features], model_dtype),\n    )\n\n\ndef dequantize_block_scale_float8_gemv(\n    x: Tensor,\n    w: Tensor,\n    w_scale: Tensor,\n    expert_indices: Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n) -> Tensor:\n    \"\"\"GEMV for project-in (e1-e3) or project-out (e2) in MLP but the weight is quantized in\n    fp8 e5m2 or e4m3. 
It needs to be dequantized before the GEMV computation.\n\n    Parameters\n    ----------\n    x : Tensor\n        For project-in, the input tensor of shape (1, in_features); and for project-out, the input\n        shape is (experts_per_tok, in_features), where `experts_per_tok` is the number of activated\n        experts per token.\n\n    w : Tensor\n        The quantized weight tensor of shape (local_experts, out_features, in_features).\n\n    w_scale : Tensor\n        The scale tensor of shape\n        (local_experts, out_features // block_size[0], in_features // block_size[1])\n\n    expert_indices : Tensor\n        The expert index tensor of shape (1, experts_per_tok), holding the ids of the activated\n        experts, where `experts_per_tok` is the number of activated experts per token.\n\n    block_size : Tuple[int, int]\n        The block size of the weight tensor.\n\n    out_dtype : str\n        The output dtype of the GEMV computation.\n\n    Returns\n    -------\n    out : Tensor\n        The output tensor of shape (experts_per_tok, out_features) in `out_dtype`.\n    \"\"\"\n    x_leading_dim, in_features = x.shape\n    local_experts, out_features, k = w.shape\n    _, experts_per_tok = expert_indices.shape\n    model_dtype = x.dtype\n    quantize_dtype = w.dtype\n\n    assert out_features % block_size[0] == 0\n    assert k % block_size[1] == 0\n\n    def _dequantize(w, s, e, i, j):\n        return w[e, i, j].astype(model_dtype) * s[e, i // block_size[0], j // block_size[1]].astype(\n            model_dtype\n        )\n\n    def load_x(x, e, j):\n        return x[0, j] if x_leading_dim == 1 else x[e, j]\n\n    @T.prim_func(private=True)\n    def _func(\n        x: T.Buffer((x_leading_dim, in_features), model_dtype),\n        w: T.Buffer((local_experts, out_features, k), quantize_dtype),\n        w_scale: T.Buffer(\n            (local_experts, out_features // block_size[0], k // block_size[1]),\n            \"float32\",\n        ),\n        expert_indices: T.Buffer((1, experts_per_tok), \"int32\"),\n        o: T.Buffer((experts_per_tok, out_features), out_dtype),\n    ):\n        T.func_attr({\"op_pattern\": 4, \"tir.noalias\": 
True})  # kOutEWiseFusable\n        for expert_id in T.thread_binding(experts_per_tok, thread=\"blockIdx.y\"):\n            with T.sblock(\"gemv_o\"):\n                e = T.axis.spatial(experts_per_tok, expert_id)\n                y = T.sblock_alloc_buffer((out_features, in_features), model_dtype)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"dequantize\"):\n                        i, j = T.axis.remap(\"SS\", [i1, i2])\n                        y[i, j] = _dequantize(w, w_scale, expert_indices[0, e], i, j)\n                for i1, i2 in T.grid(out_features, in_features):\n                    with T.sblock(\"gemv\"):\n                        i, j = T.axis.remap(\"SR\", [i1, i2])\n                        with T.init():\n                            o[e, i] = T.cast(T.float16(0), out_dtype)\n                        o[e, i] += (load_x(x, e, j) * y[i, j]).astype(out_dtype)\n\n    return op.tensor_ir_op(\n        _func,\n        \"moe_dequantize_gemv\",\n        args=[x, w, w_scale, expert_indices],\n        out=Tensor.placeholder([experts_per_tok, out_features], out_dtype),\n    )\n\n\ndef group_gemm(x: Tensor, w: Tensor, indptr: Tensor):  # pylint: disable=too-many-statements\n    \"\"\"Group GEMM in MoE models.\n\n    Parameters\n    ----------\n    x : Tensor\n        Input tensor of shape (batch_size, in_features), where `batch_size` may be a dynamic dimension.\n\n    w : Tensor\n        Weight tensor of shape (num_local_experts, out_features, in_features).\n        `w[i, :, :]` is the weight matrix for the `i`-th local expert.\n\n    indptr : Tensor\n        Index pointer tensor of shape (num_local_experts + 1,).\n        `x[indptr[i] : indptr[i + 1]]` is the input for the `i`-th local expert.\n\n    Returns\n    -------\n    out : Tensor\n        Output tensor of shape (batch_size, out_features).\n    \"\"\"\n    # NOTE: Currently it assumes x.dtype == w.dtype, but the constraint can be relaxed easily.\n    
(num_local_experts, out_features, in_features), dtype = w.shape, w.dtype\n\n    assert x.shape[1:] == [in_features] and x.dtype == dtype\n    assert indptr.shape == [num_local_experts + 1] and indptr.dtype == \"int32\"\n\n    Ne, N, K = num_local_experts, out_features, in_features\n    BLK_M, BLK_N, BLK_K = 8, 128, 32\n    TX, TY, CTA_COUNT = 8, 32, 1024\n    VEC_X, VEC_W, VEC_O, VEC_DOT = 1, 1, 1, 1\n    UNROLL = 64\n    STORAGE_ALIGN = False\n    assert BLK_K % 8 == 0\n    tiles_per_row = (N + BLK_N - 1) // BLK_N\n    zero = tir.const(0, dtype)\n\n    @T.prim_func(private=True)\n    def _func(  # pylint: disable=too-many-statements\n        var_x: T.handle,\n        var_w: T.handle,\n        var_indptr: T.handle,\n        var_o: T.handle,\n    ):\n        T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": True})\n        B = T.int32(is_size_var=True)\n        X = T.match_buffer(var_x, (B, K), dtype)\n        W = T.match_buffer(var_w, (Ne, N, K), dtype)\n        indptr = T.match_buffer(var_indptr, (Ne + 1,), \"int32\")\n        O = T.match_buffer(var_o, (B, N), dtype)\n\n        for _bx in T.thread_binding(CTA_COUNT, thread=\"blockIdx.x\"):\n            with T.sblock(\"CTA\"):\n                bx = T.axis.spatial(CTA_COUNT, _bx)\n                T.reads(indptr[:], X[:, :], W[:, :, :])\n                T.writes(O[:, :])\n                # pylint: disable=redefined-builtin\n                sum = T.sblock_alloc_buffer((2,), \"int32\", scope=\"local\")\n                row = T.sblock_alloc_buffer((2,), \"int32\", scope=\"local\")\n                cur_e = T.sblock_alloc_buffer((1,), \"int32\", scope=\"local\")\n                tile_id = T.sblock_alloc_buffer((1,), \"int32\", scope=\"local\")\n                # pylint: enable=redefined-builtin\n                sum[0] = 0\n                sum[1] = T.ceildiv(indptr[1] - indptr[0], BLK_M) * tiles_per_row\n                row[0] = 0\n                row[1] = indptr[1] - indptr[0]\n                cur_e[0] = 0\n          
      tile_id[0] = bx\n                while T.tvm_thread_invariant(cur_e[0] < Ne):  # pylint: disable=no-member\n                    # move to the current group\n                    while sum[1] <= tile_id[0] and cur_e[0] < Ne:\n                        cur_e[0] += 1\n                        if cur_e[0] < Ne:\n                            e: T.int32 = cur_e[0]\n                            delta: T.int32 = indptr[e + 1] - indptr[e]\n                            sum[0] = sum[1]\n                            sum[1] += T.ceildiv(delta, BLK_M) * tiles_per_row\n                            row[0] = row[1]\n                            row[1] += delta\n                    # sync threads to make sure all threads have the same tile position\n                    T.tvm_storage_sync(\"shared\")\n                    if T.tvm_thread_invariant(cur_e[0] < Ne):  # pylint: disable=no-member\n                        # fetch current tile position\n                        e: T.int32 = cur_e[0]  # type: ignore[no-redef]\n                        num_tiles: T.int32 = tile_id[0] - sum[0]\n                        m_offset: T.int32 = BLK_M * T.floordiv(num_tiles, tiles_per_row) + row[0]\n                        n_offset: T.int32 = BLK_N * T.floormod(num_tiles, tiles_per_row)\n                        with T.sblock(\"gemm\"):\n                            T.reads(\n                                row[1],\n                                X[m_offset : m_offset + BLK_M, :],\n                                W[e, n_offset : n_offset + BLK_N, :],\n                            )\n                            T.writes(\n                                O[\n                                    m_offset : m_offset + BLK_M,\n                                    n_offset : n_offset + BLK_N,\n                                ]\n                            )\n                            X_tile = T.sblock_alloc_buffer((BLK_M, K), dtype, scope=\"shared\")\n                            W_tile = 
T.sblock_alloc_buffer((BLK_N, K), dtype, scope=\"shared\")\n                            O_tile = T.sblock_alloc_buffer((BLK_M, BLK_N), dtype, scope=\"local\")\n                            for a0, a1 in T.grid(BLK_M, K):\n                                with T.sblock(\"X_shared\"):\n                                    i, j = T.axis.remap(\"SS\", [a0, a1])\n                                    X_tile[i, j] = T.if_then_else(\n                                        m_offset + i < row[1],\n                                        X[m_offset + i, j],\n                                        zero,\n                                    )\n                            for a0, a1 in T.grid(BLK_N, K):\n                                with T.sblock(\"W_shared\"):\n                                    i, j = T.axis.remap(\"SS\", [a0, a1])\n                                    W_tile[i, j] = T.if_then_else(\n                                        n_offset + i < N,\n                                        W[e, n_offset + i, j],\n                                        zero,\n                                    )\n                            for a0, a1, a2 in T.grid(BLK_M, BLK_N, K):\n                                with T.sblock(\"compute\"):\n                                    i, j, k = T.axis.remap(\"SSR\", [a0, a1, a2])\n                                    with T.init():\n                                        O_tile[i, j] = zero\n                                    O_tile[i, j] += X_tile[i, k] * W_tile[j, k]\n                            for a0, a1 in T.grid(BLK_M, BLK_N):\n                                with T.sblock(\"store\"):\n                                    i, j = T.axis.remap(\"SS\", [a0, a1])\n                                    if m_offset + i < row[1] and n_offset + j < N:\n                                        O[m_offset + i, n_offset + j] = O_tile[i, j]\n                    # move to next tile\n                    tile_id[0] += CTA_COUNT\n\n    def 
_schedule():\n        sch = s_tir.Schedule(_func)\n\n        def _cooperative_fetch(block, vec_len):\n            num_loops = len(sch.get_loops(block))\n            sch.compute_at(block, ko, preserve_unit_loops=True)\n            loops = sch.get_loops(block)[-num_loops:]\n            ty, tx, _, vec = sch.split(\n                sch.fuse(*loops),\n                factors=[TY, TX, None, vec_len],\n            )\n            sch.vectorize(vec)\n            sch.bind(ty, \"threadIdx.y\")\n            sch.bind(tx, \"threadIdx.x\")\n            if STORAGE_ALIGN:\n                sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len)\n            return block\n\n        main_block = sch.get_sblock(\"compute\")\n        x, y, k = sch.get_loops(main_block)\n        ty, yi = sch.split(y, [TY, None])\n        tx, xi, vec_c = sch.split(x, [TX, None, VEC_DOT])\n        ko, ki = sch.split(k, factors=[None, BLK_K])\n        sch.reorder(ty, tx, ko, ki, yi, xi, vec_c)\n        sch.bind(ty, \"threadIdx.y\")\n        sch.bind(tx, \"threadIdx.x\")\n        sch.vectorize(vec_c)\n        if UNROLL > 0:\n            sch.annotate(tx, ann_key=\"pragma_auto_unroll_max_step\", ann_val=UNROLL)\n            sch.annotate(tx, ann_key=\"pragma_unroll_explicit\", ann_val=1)\n        l2g = sch.get_sblock(\"store\")\n        sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)\n        _, v = sch.split(sch.get_loops(l2g)[-1], [None, VEC_O])\n        sch.vectorize(v)\n        _cooperative_fetch(sch.get_sblock(\"X_shared\"), vec_len=VEC_X)\n        _cooperative_fetch(sch.get_sblock(\"W_shared\"), vec_len=VEC_W)\n        sch.decompose_reduction(main_block, ko)\n        return sch.mod[\"main\"]\n\n    return op.tensor_ir_op(\n        _schedule(),\n        \"group_gemm\",\n        args=[x, w, indptr],\n        out=Tensor.placeholder([x.shape[0], out_features], dtype),\n    )\n\n\ndef dequantize_group_gemm(\n    x: Tensor,\n    w: Tensor,\n    scale: Tensor,\n    indptr: Tensor,\n    
quantize_dtype: str,\n    indptr_dtype: str,\n    group_size: int,\n):\n    \"\"\"Group GEMM in MoE models but the weight is quantized.\n\n    Parameters\n    ----------\n    x : Tensor\n        Input tensor of shape (batch_size, in_features), where `batch_size` may be a dynamic dimension.\n\n    w : Tensor\n        Weight tensor of shape (num_local_experts, out_features, in_features // n), where n is the\n        number of elements per storage dtype, e.g. if the storage dtype is uint32, and the quantize\n        dtype is int4, then n is 8.\n\n    scale : Tensor\n        The scale tensor of shape (num_local_experts, out_features, in_features // group_size).\n\n    indptr : Tensor\n        Index pointer tensor of shape (num_local_experts + 1,). `x[indptr[i] : indptr[i + 1]]` is\n        the input for the `i`-th local expert.\n\n    quantize_dtype : str\n        The quantize dtype of the weight tensor, which is usually int3, int4 or fp8, etc.\n\n    indptr_dtype : str\n        The dtype of the index pointer tensor, which can be int32 or int64.\n\n    group_size : int\n        The number of elements in each quantization group, e.g. 32 or 128.\n\n    Returns\n    -------\n    out : Tensor\n        Output tensor of shape (batch_size, out_features).\n    \"\"\"\n    (_, in_features), model_dtype = x.shape, x.dtype\n    (num_local_experts, out_features, _), storage_dtype = w.shape, w.dtype\n    quantize_dtype_bits = DataType(quantize_dtype).bits\n    num_elem_per_storage = DataType(storage_dtype).bits // quantize_dtype_bits\n    num_group = (in_features + group_size - 1) // group_size\n    num_storage = group_size // num_elem_per_storage * num_group\n\n    def _dequantize(w, s, e, i, j):\n        tir_bin_mask = tir.const((1 << quantize_dtype_bits) - 1, storage_dtype)\n        tir_max_int = tir.const((2 ** (quantize_dtype_bits - 1)) - 1, model_dtype)\n        w = w[e, i, j // num_elem_per_storage]\n        s = s[e, i, j // group_size]\n        shift = (j % 
num_elem_per_storage * quantize_dtype_bits).astype(storage_dtype)\n        w = tir.bitwise_and(tir.shift_right(w, shift), tir_bin_mask).astype(model_dtype)\n        return (w - tir_max_int) * s\n\n    Ne, N, K = num_local_experts, out_features, in_features\n    BLK_M, BLK_N, BLK_K = 8, 128, 32\n    TX, TY, CTA_COUNT = 8, 32, 1024\n    VEC_X, VEC_W, VEC_O, VEC_DOT = 1, 1, 1, 1\n    UNROLL = 64\n    STORAGE_ALIGN = False\n    assert BLK_K % 8 == 0\n    tiles_per_row = (N + BLK_N - 1) // BLK_N\n    zero = tir.const(0, model_dtype)\n    if indptr_dtype == \"int64\":\n        indptr = op.pad(indptr, [1, 0], \"constant\", 0)\n\n    @T.prim_func(private=True)\n    def _func(\n        var_x: T.handle,\n        w: T.Buffer((Ne, N, num_storage), storage_dtype),\n        scale: T.Buffer((Ne, N, num_group), model_dtype),\n        indptr: T.Buffer((Ne + 1,), indptr_dtype),\n        var_o: T.handle,\n    ):\n        T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": True})\n        B = T.int32(is_size_var=True)\n        X = T.match_buffer(var_x, (B, K), model_dtype)\n        O = T.match_buffer(var_o, (B, N), model_dtype)\n        for _bx in T.thread_binding(CTA_COUNT, thread=\"blockIdx.x\"):\n            with T.sblock(\"CTA\"):\n                bx = T.axis.spatial(CTA_COUNT, _bx)\n                T.reads(X[:, :], w[:, :, :], scale[:, :, :], indptr[:])\n                T.writes(O[:, :])\n                # pylint: disable=redefined-builtin\n                sum = T.sblock_alloc_buffer((2,), indptr_dtype, scope=\"local\")\n                row = T.sblock_alloc_buffer((2,), indptr_dtype, scope=\"local\")\n                cur_e = T.sblock_alloc_buffer((1,), indptr_dtype, scope=\"local\")\n                tile_id = T.sblock_alloc_buffer((1,), indptr_dtype, scope=\"local\")\n                # pylint: enable=redefined-builtin\n                sum[0] = 0\n                sum[1] = T.ceildiv(indptr[1] - indptr[0], BLK_M) * tiles_per_row\n                row[0] = 0\n                row[1] 
= indptr[1] - indptr[0]\n                cur_e[0] = 0\n                tile_id[0] = bx\n                while T.tvm_thread_invariant(cur_e[0] < Ne):  # pylint: disable=no-member\n                    # move to the current group\n                    while sum[1] <= tile_id[0] and cur_e[0] < Ne:\n                        cur_e[0] += 1\n                        if cur_e[0] < Ne:\n                            e = cur_e[0]\n                            delta = indptr[e + 1] - indptr[e]\n                            sum[0] = sum[1]\n                            sum[1] += T.ceildiv(delta, BLK_M) * tiles_per_row\n                            row[0] = row[1]\n                            row[1] += delta\n                    # sync threads to make sure all threads have the same tile position\n                    T.tvm_storage_sync(\"shared\")\n                    if T.tvm_thread_invariant(cur_e[0] < Ne):  # pylint: disable=no-member\n                        # fetch current tile position\n                        e = cur_e[0]  # type: ignore[no-redef]\n                        num_tiles = tile_id[0] - sum[0]\n                        m_offset = T.floordiv(num_tiles, tiles_per_row) * BLK_M + row[0]\n                        n_offset = T.floormod(num_tiles, tiles_per_row) * BLK_N\n                        with T.sblock(\"gemm\"):\n                            T.reads(\n                                row[1],\n                                X[m_offset : m_offset + BLK_M, :],\n                                w[e, n_offset : n_offset + BLK_N, :],\n                                scale[e, n_offset : n_offset + BLK_N, :],\n                            )\n                            T.writes(\n                                O[\n                                    m_offset : m_offset + BLK_M,\n                                    n_offset : n_offset + BLK_N,\n                                ]\n                            )\n                            X_tile = T.sblock_alloc_buffer((BLK_M, K), 
model_dtype, scope=\"shared\")\n                            W_tile = T.sblock_alloc_buffer((BLK_N, K), model_dtype, scope=\"shared\")\n                            O_tile = T.sblock_alloc_buffer((BLK_M, BLK_N), \"float32\", scope=\"local\")\n                            for a0, a1 in T.grid(BLK_M, K):\n                                with T.sblock(\"X_shared\"):\n                                    i, j = T.axis.remap(\"SS\", [a0, a1])\n                                    X_tile[i, j] = T.if_then_else(\n                                        m_offset + i < row[1],\n                                        X[m_offset + i, j],\n                                        zero,\n                                    )\n                            for a0, a1 in T.grid(BLK_N, K):\n                                with T.sblock(\"W_shared\"):\n                                    i, j = T.axis.remap(\"SS\", [a0, a1])\n                                    W_tile[i, j] = T.if_then_else(\n                                        n_offset + i < N,\n                                        _dequantize(w, scale, e, n_offset + i, j),\n                                        zero,\n                                    )\n                            for a0, a1, a2 in T.grid(BLK_M, BLK_N, K):\n                                with T.sblock(\"compute\"):\n                                    i, j, k = T.axis.remap(\"SSR\", [a0, a1, a2])\n                                    with T.init():\n                                        O_tile[i, j] = zero\n                                    O_tile[i, j] += X_tile[i, k] * W_tile[j, k]\n                            for a0, a1 in T.grid(BLK_M, BLK_N):\n                                with T.sblock(\"store\"):\n                                    i, j = T.axis.remap(\"SS\", [a0, a1])\n                                    if m_offset + i < row[1] and n_offset + j < N:\n                                        O[m_offset + i, n_offset + j] = O_tile[i, j]\n        
            # move to next tile\n                    tile_id[0] += CTA_COUNT\n\n    def _schedule():\n        sch = s_tir.Schedule(_func)\n\n        def _cooperative_fetch(block, vec_len):\n            num_loops = len(sch.get_loops(block))\n            sch.compute_at(block, ko, preserve_unit_loops=True)\n            loops = sch.get_loops(block)[-num_loops:]\n            ty, tx, _, vec = sch.split(\n                sch.fuse(*loops),\n                factors=[TY, TX, None, vec_len],\n            )\n            sch.vectorize(vec)\n            sch.bind(ty, \"threadIdx.y\")\n            sch.bind(tx, \"threadIdx.x\")\n            if STORAGE_ALIGN:\n                sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len)\n            return block\n\n        main_block = sch.get_sblock(\"compute\")\n        x, y, k = sch.get_loops(main_block)\n        ty, yi = sch.split(y, [TY, None])\n        tx, xi, vec_c = sch.split(x, [TX, None, VEC_DOT])\n        ko, ki = sch.split(k, factors=[None, BLK_K])\n        sch.reorder(ty, tx, ko, ki, yi, xi, vec_c)\n        sch.bind(ty, \"threadIdx.y\")\n        sch.bind(tx, \"threadIdx.x\")\n        sch.vectorize(vec_c)\n        if UNROLL > 0:\n            sch.annotate(tx, ann_key=\"pragma_auto_unroll_max_step\", ann_val=UNROLL)\n            sch.annotate(tx, ann_key=\"pragma_unroll_explicit\", ann_val=1)\n        l2g = sch.get_sblock(\"store\")\n        sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True)\n        _, v = sch.split(sch.get_loops(l2g)[-1], [None, VEC_O])\n        sch.vectorize(v)\n        _cooperative_fetch(sch.get_sblock(\"X_shared\"), vec_len=VEC_X)\n        _cooperative_fetch(sch.get_sblock(\"W_shared\"), vec_len=VEC_W)\n        sch.decompose_reduction(main_block, ko)\n        return sch.mod[\"main\"]\n\n    return op.tensor_ir_op(\n        _schedule(),\n        \"dequantize_group_gemm\",\n        args=[x, w, scale, indptr],\n        out=Tensor.placeholder([x.shape[0], out_features], model_dtype),\n    )\n"
  },
  {
    "path": "python/mlc_llm/op/moe_misc.py",
    "content": "\"\"\"Mixture of Experts operators\"\"\"\n\nfrom functools import reduce\nfrom typing import Literal, Optional, Tuple, Union\n\nimport numpy as np\nfrom tvm import te, tir\nfrom tvm.relax.frontend.nn import IntExpr, Tensor, op\nfrom tvm.script import tir as T\n\n# mypy: disable-error-code=\"attr-defined,name-defined\"\n# pylint: disable=line-too-long,too-many-locals,invalid-name\n\n\ndef moe_sum(x: Tensor, dim: int) -> Tensor:\n    \"\"\"Compute the sum of the input tensor along the given axis. It is specialized for the MoE\n    case where `x.ndim == 3` and `x.shape[1] == num_experts_per_tok` (which is 1 or 2); other\n    shapes fall back to `op.sum`.\n    \"\"\"\n\n    if x.shape[1] == 1:\n        return x.reshape(x.shape[0], x.shape[2])\n\n    if x.ndim == 3 and x.shape[1] == 2:\n        return op.tensor_expr_op(\n            lambda x: te.compute(\n                (x.shape[0], x.shape[2]),\n                lambda i, j: x[i, 0, j] + x[i, 1, j],\n                name=\"sum_2\",\n            ),\n            \"sum\",\n            args=[x],\n        )\n    return op.sum(x, axis=dim)\n\n\ndef _gating_topk_init_local_top_k(k_val, dtype, local_top_k, local_top_k_index):\n    for t in range(k_val):\n        T.buffer_store(local_top_k, T.min_value(dtype), indices=[t])\n    for t in range(k_val):\n        T.buffer_store(local_top_k_index, -1, indices=[t])\n\n\ndef _gating_topk_process_value(  # pylint: disable=too-many-arguments\n    k_val, x, local_top_k, local_top_k_index, vi, vk\n):\n    if_frames = [T.If(x[vi, vk] > local_top_k[i]) for i in range(k_val)]\n    then_frames = [T.Then() for _ in range(k_val)]\n    else_frames = [T.Else() for _ in range(k_val - 1)]\n    for i in range(k_val):\n        if_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call\n        with then_frames[i]:\n            for j in range(k_val - 1, i, -1):\n                T.buffer_store(local_top_k, local_top_k[j - 1], indices=[j])\n                T.buffer_store(local_top_k_index, local_top_k_index[j - 1], 
indices=[j])\n            T.buffer_store(local_top_k, x[vi, vk], indices=[i])\n            T.buffer_store(local_top_k_index, vk, indices=[i])\n        if i != k_val - 1:\n            else_frames[i].__enter__()  # pylint: disable=unnecessary-dunder-call\n\n    for i in range(k_val - 1, -1, -1):\n        if i != k_val - 1:\n            else_frames[i].__exit__(None, None, None)\n        if_frames[i].__exit__(None, None, None)\n\n\ndef gating_topk(scores: Tensor, k: int) -> Tuple[Tensor, Tensor]:\n    \"\"\"Compute the top-k experts and their scores.\n\n    Parameters\n    ----------\n    scores : Tensor\n        The input tensor with shape [batch_size, num_local_experts].\n\n    k : int\n        The number of top elements to be selected, which is `num_experts_per_tok` in MoE.\n\n    Returns\n    -------\n    expert_weights: Tensor\n        The top-k expert scores with shape [batch_size, k].\n\n    expert_indices: Tensor\n        The top-k expert indices with shape [batch_size, k].\n    \"\"\"\n    (batch_size, num_local_experts), dtype = scores.shape, scores.dtype\n    index_dtype = \"int32\"\n\n    TX = 1024\n\n    def _get_topk_func(k_val: int):\n        @T.prim_func(private=True)\n        def topk_func(\n            var_x: T.handle,\n            var_out: T.handle,\n            var_out_index: T.handle,\n        ) -> None:\n            T.func_attr({\"tir.noalias\": True, \"tir.is_scheduled\": True})\n            batch_size = T.int64()\n            x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)\n            out = T.match_buffer(var_out, (batch_size, k_val), dtype)\n            out_index = T.match_buffer(var_out_index, (batch_size, k_val), index_dtype)\n            local_top_k = T.sblock_alloc_buffer((k_val,), dtype=dtype, scope=\"local\")\n            local_top_k_index = T.sblock_alloc_buffer((k_val,), dtype=index_dtype, scope=\"local\")\n            for io in T.thread_binding(0, T.ceildiv(batch_size, TX), \"blockIdx.x\"):\n                for ii in 
T.thread_binding(0, TX, \"threadIdx.x\"):\n                    with T.sblock(\"top_k\"):\n                        vi = T.axis.spatial(batch_size, io * TX + ii)\n                        T.where(io * TX + ii < batch_size)\n                        with T.sblock(\"init\"):\n                            _gating_topk_init_local_top_k(\n                                k_val, dtype, local_top_k, local_top_k_index\n                            )\n                        for k in range(num_local_experts):\n                            with T.sblock(\"update\"):\n                                vk = T.axis.remap(\"S\", [k])\n                                _gating_topk_process_value(\n                                    k_val, x, local_top_k, local_top_k_index, vi, vk\n                                )\n                        for j in T.unroll(k_val):\n                            with T.sblock(\"output\"):\n                                vj = T.axis.remap(\"S\", [j])\n                                out[vi, vj] = local_top_k[vj]\n                                out_index[vi, vj] = local_top_k_index[vj]\n\n        return topk_func\n\n    return op.tensor_ir_op(\n        _get_topk_func(k),\n        f\"top{k}\",\n        args=[scores],\n        out=(\n            Tensor.placeholder([batch_size, k], dtype),\n            Tensor.placeholder([batch_size, k], index_dtype),\n        ),\n    )\n\n\ndef gating_softmax_topk(  # pylint: disable=too-many-statements\n    x: Tensor, k: int, norm_topk_prob=True\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Compute the softmax scores, choose the top-k experts, and return the selected scores.\n\n    Parameters\n    ----------\n    x : Tensor\n        The input tensor with shape [batch_size, num_local_experts].\n\n    k : int\n        The number of top elements to be selected, which is `num_experts_per_tok` in MoE.\n\n    norm_topk_prob : bool\n        If True, apply softmax over only the top-k logits so that the selected scores sum to 1.\n        Otherwise, apply softmax over all experts and then select the top-k scores.\n\n    Returns\n    -------\n    expert_weights: Tensor\n   
     The top-k expert scores with shape [batch_size, k].\n\n    expert_indices: Tensor\n        The top-k expert indices with shape [batch_size, k].\n    \"\"\"\n    (batch_size, num_local_experts), dtype = x.shape, x.dtype\n    index_dtype = \"int32\"\n\n    TX = 1024\n\n    def _get_topk_softmax_norm_func(k_val: int):\n        def _nested_max(local_top_k_f32):\n            expr = local_top_k_f32[0]\n            for i in range(1, k_val):\n                expr = T.max(expr, local_top_k_f32[i])\n            return expr\n\n        def _nested_sum(local_top_k_f32, local_top_k_max):\n            expr = T.exp(local_top_k_f32[0] - local_top_k_max[0])\n            for i in range(1, k_val):\n                expr = expr + T.exp(local_top_k_f32[i] - local_top_k_max[0])\n            return expr\n\n        @T.prim_func(private=True)\n        def topk_softmax_norm_func(\n            var_x: T.handle,\n            var_out: T.handle,\n            var_out_index: T.handle,\n        ) -> None:\n            T.func_attr({\"tir.noalias\": True, \"tir.is_scheduled\": True})\n            batch_size = T.int64()\n            x = T.match_buffer(var_x, (batch_size, num_local_experts), dtype)\n            out = T.match_buffer(var_out, (batch_size, k_val), dtype)\n            out_index = T.match_buffer(var_out_index, (batch_size, k_val), index_dtype)\n            local_top_k = T.sblock_alloc_buffer((k_val,), dtype=dtype, scope=\"local\")\n            local_top_k_index = T.sblock_alloc_buffer((k_val,), dtype=index_dtype, scope=\"local\")\n            local_top_k_f32 = T.sblock_alloc_buffer((k_val,), dtype=\"float32\", scope=\"local\")\n            local_top_k_max = T.sblock_alloc_buffer((1,), dtype=\"float32\", scope=\"local\")\n            for io in T.thread_binding(0, T.ceildiv(batch_size, TX), \"blockIdx.x\"):\n                for ii in T.thread_binding(0, TX, \"threadIdx.x\"):\n                    with T.sblock(\"top_k\"):\n                        vi = T.axis.spatial(batch_size, io * TX + 
ii)\n                        T.where(io * TX + ii < batch_size)\n                        with T.sblock(\"init\"):\n                            _gating_topk_init_local_top_k(\n                                k_val, dtype, local_top_k, local_top_k_index\n                            )\n                        for k in range(num_local_experts):\n                            with T.sblock(\"update\"):\n                                vk = T.axis.remap(\"S\", [k])\n                                _gating_topk_process_value(\n                                    k_val, x, local_top_k, local_top_k_index, vi, vk\n                                )\n                        for j in T.unroll(k_val):\n                            with T.sblock(\"cast\"):\n                                vj = T.axis.remap(\"S\", [j])\n                                local_top_k_f32[vj] = T.cast(local_top_k[vj], \"float32\")\n                        with T.sblock(\"max\"):\n                            local_top_k_max[0] = _nested_max(local_top_k_f32)\n                        for j in T.unroll(k_val):\n                            with T.sblock(\"output\"):\n                                vj = T.axis.remap(\"S\", [j])\n                                out[vi, vj] = T.cast(\n                                    T.exp(local_top_k_f32[vj] - local_top_k_max[0])\n                                    / _nested_sum(local_top_k_f32, local_top_k_max),\n                                    dtype,\n                                )\n                                out_index[vi, vj] = local_top_k_index[vj]\n\n        return topk_softmax_norm_func\n\n    if norm_topk_prob:\n        return op.tensor_ir_op(\n            _get_topk_softmax_norm_func(k),\n            f\"top{k}_softmax\",\n            args=[x],\n            out=(\n                Tensor.placeholder([batch_size, k], dtype),\n                Tensor.placeholder([batch_size, k], index_dtype),\n            ),\n        )\n\n    expert_score = 
op.softmax(x.astype(\"float32\"), axis=-1).astype(dtype)\n    return gating_topk(expert_score, k)\n\n\ndef group_limited_greedy_topk(  # pylint: disable=too-many-arguments\n    scores: Tensor,  # (num_tokens, num_routed_experts)\n    top_k: int,\n    num_routed_experts: int,\n    n_group: int,\n    topk_group: int,\n    topk_method: Literal[\"group_limited_greedy\", \"noaux_tc\"],\n    num_tokens: IntExpr,\n    e_score_correction_bias: Optional[Tensor],\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Group-limited greedy top-k expert selection.\n\n    Parameters\n    ----------\n    scores : Tensor\n        The input tensor with shape [num_tokens, num_routed_experts].\n\n    top_k : int\n        The number of top elements to be selected, which is `num_experts_per_tok` in MoE.\n\n    num_routed_experts : int\n        The number of routed experts.\n\n    n_group : int\n        The number of groups.\n\n    topk_group : int\n        The number of top-k groups to be selected.\n\n    topk_method : Literal[\"group_limited_greedy\", \"noaux_tc\"]\n        The method to select the top-k groups.\n\n    num_tokens : IntExpr\n        The number of tokens.\n\n    e_score_correction_bias : Optional[Tensor]\n        The bias of the expert scores. 
Only available for \"noaux_tc\".\n\n    Returns\n    -------\n    expert_weights : Tensor\n        The top-k expert scores with shape [num_tokens, top_k].\n\n    expert_indices : Tensor\n        The top-k expert indices with shape [num_tokens, top_k].\n    \"\"\"\n    assert scores.dtype == \"float32\"\n    scores_for_choice = scores\n    if topk_method == \"noaux_tc\":\n        assert e_score_correction_bias is not None\n        assert e_score_correction_bias.dtype == \"float32\"\n        scores_for_choice = scores + e_score_correction_bias\n    group_size = num_routed_experts // n_group\n    if topk_method == \"noaux_tc\":\n        group_scores = op.sum(\n            gating_topk(\n                scores_for_choice.reshape(num_tokens * n_group, group_size),\n                2,\n            )[0],\n            axis=-1,\n        ).reshape(num_tokens, n_group)\n    else:\n        group_scores = op.max(\n            scores_for_choice.reshape(num_tokens * n_group, group_size), axis=-1\n        ).reshape(num_tokens, n_group)\n    group_idx = gating_topk(group_scores, topk_group)[1]  # (num_tokens, top_k_group)\n\n    @T.prim_func(private=True)\n    def group_limited_mask_scores(\n        var_scores: T.handle, var_group_idx: T.handle, var_output: T.handle\n    ):\n        T.func_attr({\"tir.noalias\": True})\n        scores = T.match_buffer(\n            var_scores, (num_tokens, num_routed_experts), dtype=scores_for_choice.dtype\n        )\n        group_idx_tir = T.match_buffer(\n            var_group_idx, (num_tokens, topk_group), dtype=group_idx.dtype\n        )\n        output = T.match_buffer(\n            var_output, (num_tokens, num_routed_experts), dtype=scores_for_choice.dtype\n        )\n        for i, j, k in T.grid(num_tokens, topk_group, group_size):\n            with T.sblock(\"mask_scores\"):\n                vi, vj, vk = T.axis.remap(\"SSS\", [i, j, k])\n                output[vi, group_idx_tir[vi, vj] * group_size + vk] = scores[\n                    vi, 
group_idx_tir[vi, vj] * group_size + vk\n                ]\n\n    tmp_scores = op.tensor_ir_inplace_op(\n        group_limited_mask_scores,\n        \"group_limited_mask_scores\",\n        args=[\n            scores_for_choice,\n            group_idx,\n            op.full(\n                scores_for_choice.shape,\n                float(np.finfo(\"float32\").min),\n                dtype=scores_for_choice.dtype,\n            ),\n        ],\n        inplace_indices=[2],\n        out=Tensor.placeholder(scores_for_choice.shape, scores_for_choice.dtype),\n    )\n\n    expert_weights, expert_indices = gating_topk(tmp_scores, top_k)\n    if topk_method == \"noaux_tc\":\n\n        @T.prim_func(private=True)\n        def gather_scores(var_scores: T.handle, var_expert_indices: T.handle, var_output: T.handle):\n            T.func_attr({\"tir.noalias\": True})\n            scores = T.match_buffer(\n                var_scores,\n                (num_tokens, num_routed_experts),\n                dtype=scores_for_choice.dtype,\n            )\n            expert_indices_tir = T.match_buffer(\n                var_expert_indices, (num_tokens, top_k), dtype=expert_indices.dtype\n            )\n            output = T.match_buffer(var_output, (num_tokens, top_k), dtype=scores_for_choice.dtype)\n            for i, j in T.grid(num_tokens, top_k):\n                with T.sblock(\"gather_scores\"):\n                    vi, vj = T.axis.remap(\"SS\", [i, j])\n                    output[vi, vj] = scores[vi, expert_indices_tir[vi, vj]]\n\n        expert_weights = op.tensor_ir_op(\n            gather_scores,\n            \"gather_scores\",\n            args=[scores, expert_indices],\n            out=Tensor.placeholder((num_tokens, top_k), scores_for_choice.dtype),\n        )\n    return expert_weights, expert_indices\n\n\ndef moe_cumsum(expert_indices: Tensor, num_local_experts: int) -> Tensor:\n    \"\"\"An operator that returns the cumsum array in MoE.\n\n    The input `expert_indices` of 
shape [batch_size, experts_per_tok] indicates the indices of\n    the activated experts for each instance in a batch. This operator first converts it to\n    `expert_mask`, a 0/1 int32 mask with shape [batch_size, num_local_experts], and then computes\n    cumsum over the transposed-then-flattened array of `expert_mask`.\n\n    For expert `e` and batch index `b` with `expert_mask[b, e] == 1`, the value\n    `cumsum[e * batch_size + b] - 1` is the position to which the `b`-th instance is moved, so\n    that the inputs to the `e`-th expert are contiguous.\n\n    Parameters\n    ----------\n    expert_indices : Tensor\n        The top-k indices with shape [batch_size, experts_per_tok], int32, where\n        `experts_per_tok` is the number of activated experts.\n\n    num_local_experts : int\n        The total number of experts.\n\n    Returns\n    -------\n    cumsum: Tensor\n        The cumsum result with shape [num_local_experts * batch_size], int32.\n\n    Example\n    -------\n    Suppose `batch_size` is 4, `experts_per_tok` is 2, the total number of experts is 6, and\n    `expert_indices` is the 2D tensor below:\n\n        [\n            [0, 1],\n            [1, 2],\n            [3, 4],\n            [2, 5],\n        ]\n\n    Then the `expert_mask` is the tensor of shape [batch_size, num_local_experts] below:\n\n        [\n            [1, 1, 0, 0, 0, 0],\n            [0, 1, 1, 0, 0, 0],\n            [0, 0, 0, 1, 1, 0],\n            [0, 0, 1, 0, 0, 1],\n        ]\n\n
    The resulting cumsum of the transposed `expert_mask` is the flattened version of the\n    2D tensor below:\n\n        [\n            [1, 1, 1, 1],\n            [2, 3, 3, 3],\n            [3, 4, 4, 5],\n            [5, 5, 6, 6],\n            [6, 6, 7, 7],\n            [7, 7, 7, 8],\n        ]\n    \"\"\"\n    batch_size, experts_per_tok = expert_indices.shape\n    expert_mask = (\n        op.tensor_expr_op(  # pylint: disable=too-many-function-args\n            lambda expert_indices: te.compute(\n                (batch_size, num_local_experts),\n                lambda i, j: tir.expr.Select(\n                    reduce(\n                        tir.Or,\n                        [expert_indices[i, k] == j for k in range(experts_per_tok)],\n                    ),\n                    true_value=tir.const(1, \"int32\"),\n                    false_value=tir.const(0, \"int32\"),\n                ),\n            ),\n            \"expert_mask\",\n            args=[expert_indices],\n        )\n        .permute_dims(1, 0)\n        .reshape(batch_size * num_local_experts)\n    )\n\n    return op.cumsum(expert_mask, axis=0, exclusive=False, dtype=\"int32\")\n\n\ndef get_indices(cumsum: Tensor, expert_indices: Tensor) -> Tuple[Tensor, Tensor]:\n    \"\"\"Returns a 1D tensor of indices that represents the shuffling plan for each instance in a\n    batch, so that the inputs to each expert are contiguous, together with the indices for the\n    reverse permutation (scatter) back to the original order.\n\n    If `reverse_indices[i] = b * experts_per_tok + j`, i.e. the flattened pair `(b, j)`, then the\n    `b`-th instance in the batch is moved to the `i`-th position in the shuffled order; `j` only\n    records that `expert_indices[b, j]` is the expert at position `i` in the shuffling plan.\n    We also compute `token_indices[i] = b` so that we can use `relax.op.take` for shuffling.\n\n    Effectively it is equivalent to the following Python code:\n\n
    .. code-block:: python\n\n        for b in range(batch_size):\n            for j in range(experts_per_tok):\n                e = expert_indices[b, j]\n                reverse_indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j\n                token_indices[cumsum[e * batch_size + b] - 1] = b\n\n    Parameters\n    ----------\n    cumsum : Tensor\n        A flattened 1D tensor whose original shape is [num_local_experts, batch_size].\n\n    expert_indices : Tensor\n        The indices of the experts with shape [batch_size, experts_per_tok].\n\n    Returns\n    -------\n    reverse_indices : Tensor\n        The indices for scattering with shape [batch_size * experts_per_tok].\n\n    token_indices : Tensor\n        The indices for shuffling with shape [batch_size * experts_per_tok].\n    \"\"\"\n    TX = 1024\n    batch_size, experts_per_tok = expert_indices.shape\n\n    @T.prim_func(private=True)\n    def _func(\n        var_cumsum: T.handle,\n        var_expert_indices: T.handle,\n        var_reverse_indices: T.handle,\n        var_token_indices: T.handle,\n    ):\n        T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": True})\n        batch_size = T.SizeVar(\"batch_size\", \"int32\")\n        cumsum_len = T.SizeVar(\"cumsum_len\", \"int32\")  # [num_local_experts * batch_size]\n        cumsum = T.match_buffer(var_cumsum, [cumsum_len], \"int32\")\n        expert_indices = T.match_buffer(var_expert_indices, [batch_size, experts_per_tok], \"int32\")\n        reverse_indices = T.match_buffer(\n            var_reverse_indices, [batch_size * experts_per_tok], \"int32\"\n        )\n        token_indices = T.match_buffer(var_token_indices, [batch_size * experts_per_tok], \"int32\")\n        for bj_o in T.thread_binding(0, T.ceildiv(batch_size * experts_per_tok, TX), \"blockIdx.x\"):\n            for bj_i in T.thread_binding(0, TX, \"threadIdx.x\"):\n                with T.sblock(\"indices\"):\n                    T.reads(expert_indices[:, :], cumsum[:])\n
                    T.writes(reverse_indices[:], token_indices[:])\n                    if bj_o * TX + bj_i < batch_size * experts_per_tok:\n                        b: T.int32 = T.floordiv(bj_o * TX + bj_i, experts_per_tok)\n                        j: T.int32 = T.floormod(bj_o * TX + bj_i, experts_per_tok)\n                        e: T.int32 = expert_indices[b, j]\n                        reverse_indices[cumsum[e * batch_size + b] - 1] = b * experts_per_tok + j\n                        token_indices[cumsum[e * batch_size + b] - 1] = b\n\n    return op.tensor_ir_op(\n        _func,\n        \"get_indices\",\n        args=[cumsum, expert_indices],\n        out=[Tensor.placeholder([batch_size * experts_per_tok], \"int32\") for _ in range(2)],\n    )\n\n\ndef get_indptr(\n    cumsum: Tensor,\n    num_local_experts: int,\n    batch_size: Union[int, tir.Var],\n    inclusive: bool,\n    out_dtype: str,\n) -> Tensor:\n    \"\"\"Extract the `indptr` array from the MoE cumsum array. The MoE cumsum array is a flattened\n    tensor whose original shape is [num_local_experts, batch_size], and the `indptr` array is a 1D\n    tensor of length `num_local_experts + 1`. The range `[indptr[i], indptr[i + 1])` indicates the\n    instances in the batch that correspond to the `i`-th expert.\n\n    Effectively, when `inclusive` is False, this operator is equivalent to the following numpy\n    code:\n\n    .. code-block:: python\n\n        indptr = np.zeros(num_local_experts + 1, dtype=np.int32)\n        indptr[0] = 0\n        for i in range(1, num_local_experts + 1):\n            indptr[i] = cumsum[i * batch_size - 1]\n        return indptr\n\n    Parameters\n    ----------\n    cumsum : Tensor\n        The prefix sum of the sparse array with shape [batch_size * num_local_experts], int32.\n\n    num_local_experts : int\n        The number of experts.\n\n    batch_size : int | tir.Var\n
        The batch size. Note that the batch size here refers to `batch_size * seq_len` in MoE,\n        and we name it `batch_size` for simplicity here only because the two dimensions are fused\n        in Mixtral.\n\n    inclusive : bool\n        Whether to return the inclusive or exclusive prefix sum as the `indptr`. If `inclusive` is\n        True, the 0-th element of the `indptr` array, which always equals 0, is omitted.\n\n    out_dtype : str\n        The output dtype.\n\n    Returns\n    -------\n    indptr : Tensor\n        The `indptr` array with shape [num_local_experts] if `inclusive` is True, otherwise\n        [num_local_experts + 1]. The `indptr` array is of type `out_dtype`.\n    \"\"\"\n\n    out_shape = [num_local_experts if inclusive else num_local_experts + 1]\n\n    @T.prim_func(private=True)\n    def _func_exclusive(var_cumsum: T.handle, var_indptr: T.handle, batch_size: T.int64):\n        T.func_attr({\"tir.noalias\": True})\n        cumsum = T.match_buffer(var_cumsum, shape=[batch_size * num_local_experts], dtype=\"int32\")\n        indptr = T.match_buffer(var_indptr, shape=out_shape, dtype=out_dtype)\n        for vi in T.serial(0, out_shape[0]):\n            with T.sblock(\"indptr\"):\n                i = T.axis.spatial(out_shape[0], vi)\n                indptr[i] = T.Select(i > 0, cumsum[i * batch_size - 1], T.int32(0))\n\n    @T.prim_func(private=True)\n    def _func_inclusive(var_cumsum: T.handle, var_indptr: T.handle, batch_size: T.int64):\n        T.func_attr({\"tir.noalias\": True})\n        cumsum = T.match_buffer(var_cumsum, shape=[batch_size * num_local_experts], dtype=\"int32\")\n        indptr = T.match_buffer(var_indptr, shape=out_shape, dtype=out_dtype)\n        for vi in T.serial(0, out_shape[0]):\n            with T.sblock(\"indptr\"):\n                i = T.axis.spatial(out_shape[0], vi)\n                indptr[i] = cumsum[(i + 1) * batch_size - 1]\n\n    assert cumsum.ndim == 1\n    return op.tensor_ir_op(\n
        _func_inclusive if inclusive else _func_exclusive,\n        \"get_expert_instance_indptr\",\n        args=[cumsum, batch_size],  # type: ignore[list-item]\n        out=Tensor.placeholder(out_shape, out_dtype),\n    )\n\n\ndef scatter_output(x: Tensor, indices: Tensor) -> Tensor:\n    \"\"\"Scatter the output of MoE experts back to the original positions.\n\n    Parameters\n    ----------\n    x : Tensor\n        The output of MoE experts with shape [batch_size * num_experts_per_tok, hidden_size].\n\n    indices : Tensor\n        The destination row of each row of `x` (e.g. the `reverse_indices` returned by\n        `get_indices`), with shape [batch_size * num_experts_per_tok].\n\n    Returns\n    -------\n    out : Tensor\n        The expert outputs restored to the original order, with shape\n        [batch_size * num_experts_per_tok, hidden_size].\n    \"\"\"\n    dtype = x.dtype\n    _, hidden_size = x.shape\n\n    @T.prim_func(private=True)\n    def _func(var_x: T.handle, var_indices: T.handle, var_out: T.handle):\n        T.func_attr({\"tir.noalias\": True})\n        indices_len = T.int64()\n        x = T.match_buffer(var_x, [indices_len, hidden_size], dtype)\n        indices = T.match_buffer(var_indices, [indices_len], \"int32\")\n        out = T.match_buffer(var_out, [indices_len, hidden_size], dtype)\n        for i in T.serial(0, indices_len):\n            for j in T.serial(0, hidden_size):\n                with T.sblock(\"scatter\"):\n                    vi, vj = T.axis.remap(\"SS\", [i, j])\n                    out[indices[vi], vj] = x[vi, vj]\n\n    return op.tensor_ir_op(\n        _func,\n        \"scatter_output\",\n        args=[x, indices],\n        out=Tensor.placeholder(x.shape, dtype),\n    )\n"
  },
  {
    "path": "python/mlc_llm/op/mrope.py",
    "content": "\"\"\"Utilities for Multimodal Rotary Position Embeddings (MRoPE).\"\"\"\n\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import List, Optional, Sequence, Tuple\n\nimport numpy as np\nfrom tvm import te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import Tensor, op\n\n\ndef _rotate_half(x: Tensor) -> Tensor:\n    \"\"\"Split the last dimension of ``x`` into halves ``(x1, x2)`` and return ``(-x2, x1)``.\"\"\"\n\n    x1, x2 = op.split(x, 2, axis=-1)\n    return op.concat([op.negative(x2), x1], dim=-1)\n\n\ndef _repeat_mrope_section(section: Sequence[int]) -> Tuple[int, ...]:\n    if not section:\n        raise ValueError(\"mrope_section must not be empty.\")\n    if any(s <= 0 for s in section):\n        raise ValueError(f\"All mrope_section entries must be positive, got {section}.\")\n    return tuple(section) * 2\n\n\ndef _split_indices_from_sizes(sizes: Sequence[int]) -> List[int]:\n    indices: List[int] = []\n    running = 0\n    # Drop the final cumulative sum so split() keeps the last chunk.\n    for size in sizes[:-1]:\n        running += size\n        indices.append(running)\n    return indices\n\n\ndef _reorder_cos_sin(\n    tensor: Tensor,\n    split_sizes: Sequence[int],\n) -> Tensor:\n    \"\"\"Reorder cos/sin tensors so that consecutive head-dimension sections cycle through T/H/W.\"\"\"\n\n    if not split_sizes:\n        raise ValueError(\"split_sizes must not be empty.\")\n    split_points = _split_indices_from_sizes(split_sizes)\n    # op.split returns a Python tuple of Tensors, so we can iterate directly.\n    sections = op.split(tensor, indices_or_sections=split_points, axis=-1)\n    reordered = []\n    for idx, chunk in enumerate(sections):\n        axis_selector = nn.Tensor.from_const(np.array([idx % 3], dtype=\"int32\"))\n        axis_slice = op.take(chunk, axis_selector, axis=0)\n        reordered.append(nn.op.squeeze(axis_slice, 0))\n    return op.concat(reordered, dim=-1)\n\n\n
class MultimodalRotaryEmbedding(nn.Module):\n    \"\"\"Generate cosine/sine tables for multimodal rotary embeddings.\"\"\"\n\n    def __init__(\n        self,\n        head_dim: int,\n        theta: float,\n        mrope_section: Sequence[int],\n        attention_scaling: float = 1.0,\n    ) -> None:\n        if head_dim % 2 != 0:\n            raise ValueError(f\"head_dim must be even for RoPE, got {head_dim}.\")\n        self.head_dim = head_dim\n        self.theta = theta\n        self.attention_scaling = attention_scaling\n        self.mrope_section = tuple(mrope_section)\n        self._inv_freq = 1.0 / (\n            theta ** (np.arange(0, head_dim, 2, dtype=\"float32\") / np.float32(head_dim))\n        )\n\n    def forward(self, reference: Tensor, position_ids: Tensor) -> Tuple[Tensor, Tensor]:\n        \"\"\"Return ``(cos, sin)`` with shape ``(3, batch, seq, head_dim)``.\"\"\"\n        if len(position_ids.shape) != 3:\n            raise ValueError(\n                \"position_ids must be rank-3 with either \"\n                \"(batch, seq, 3) or (3, batch, seq) layout, \"\n                f\"got shape {position_ids.shape}.\"\n            )\n        if isinstance(position_ids.shape[0], int) and position_ids.shape[0] == 3:\n            batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]\n            pos_tensor = op.reshape(position_ids, (3, batch_size, 1, seq_len))\n        elif isinstance(position_ids.shape[-1], int) and position_ids.shape[-1] == 3:\n            batch_size, seq_len = position_ids.shape[0], position_ids.shape[1]\n            permuted_pos = op.permute_dims(position_ids, axes=[2, 0, 1])\n            pos_tensor = op.reshape(permuted_pos, (3, batch_size, 1, seq_len))\n        else:\n            raise ValueError(\n                \"position_ids must have exactly one static dimension of size 3, \"\n                f\"got shape {position_ids.shape}.\"\n            )\n\n        dtype = reference.dtype\n        inv_freq_tensor = 
nn.Tensor.from_const(self._inv_freq.reshape(1, 1, -1, 1))\n        inv_freq_tensor = op.broadcast_to(inv_freq_tensor, (3, batch_size, self._inv_freq.size, 1))\n\n        freqs = op.matmul(inv_freq_tensor.astype(\"float32\"), pos_tensor.astype(\"float32\"))\n        freqs = op.permute_dims(freqs, axes=[0, 1, 3, 2])\n        emb = op.concat([freqs, freqs], dim=-1)\n\n        def _apply_trig(func_name: str) -> Tensor:\n            def compute(x: te.Tensor):\n                return te.compute(\n                    x.shape,\n                    lambda *indices: getattr(tir, func_name)(x[indices]),\n                    name=f\"mrope_{func_name}\",\n                )\n\n            return op.tensor_expr_op(compute, f\"mrope_{func_name}\", [emb])\n\n        cos = _apply_trig(\"cos\") * self.attention_scaling\n        sin = _apply_trig(\"sin\") * self.attention_scaling\n        return cos.astype(dtype), sin.astype(dtype)\n\n\ndef apply_multimodal_rotary_pos_emb(  # pylint: disable=too-many-arguments\n    q: Tensor,\n    k: Tensor,\n    cos: Tensor,\n    sin: Tensor,\n    mrope_section: Sequence[int],\n    unsqueeze_dim: int = 2,\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Apply multimodal rotary embedding to query and key tensors.\"\"\"\n\n    split_sizes = _repeat_mrope_section(mrope_section)\n    reordered_cos = _reorder_cos_sin(cos, split_sizes)\n    reordered_sin = _reorder_cos_sin(sin, split_sizes)\n    cos_term = op.unsqueeze(reordered_cos, dim=unsqueeze_dim)\n    sin_term = op.unsqueeze(reordered_sin, dim=unsqueeze_dim)\n    cos_term = cos_term.astype(q.dtype)\n    sin_term = sin_term.astype(q.dtype)\n    q_embed = op.add(op.multiply(q, cos_term), op.multiply(_rotate_half(q), sin_term))\n    k_embed = op.add(op.multiply(k, cos_term), op.multiply(_rotate_half(k), sin_term))\n    return q_embed, k_embed\n\n\n@dataclass\nclass VisionPositionMetadata:\n    \"\"\"Metadata required to build multimodal position IDs.\"\"\"\n\n    vision_start_token_id: int\n    image_token_id: 
int\n    video_token_id: int\n    spatial_merge_size: int\n    tokens_per_second: float\n\n    def merged_hw(self, height: int, width: int) -> Tuple[int, int]:\n        \"\"\"Return merged height/width after applying ``spatial_merge_size``.\"\"\"\n\n        if height % self.spatial_merge_size != 0 or width % self.spatial_merge_size != 0:\n            raise ValueError(\n                \"Image or video grid is not divisible by spatial_merge_size \"\n                f\"(got h={height}, w={width}, merge={self.spatial_merge_size}).\"\n            )\n        return height // self.spatial_merge_size, width // self.spatial_merge_size\n\n\ndef _text_chunk(length: int, offset: int) -> np.ndarray:\n    \"\"\"Create a text-position chunk with a shared scalar offset for T/H/W axes.\"\"\"\n\n    if length <= 0:\n        return np.zeros((3, 0), dtype=np.int64)\n    seq: np.ndarray = np.arange(length, dtype=np.int64)\n    chunk = np.broadcast_to(seq.reshape(1, -1), (3, length))\n    return chunk + offset\n\n\ndef _grid_chunk(  # pylint: disable=too-many-arguments\n    grid_t: int,\n    grid_h: int,\n    grid_w: int,\n    offset: int,\n    tokens_per_second: float,\n    second_per_grid: float,\n) -> np.ndarray:\n    if grid_t <= 0 or grid_h <= 0 or grid_w <= 0:\n        raise ValueError(\n            f\"Invalid grid shape t={grid_t}, h={grid_h}, w={grid_w} for multimodal positions.\"\n        )\n    time_axis = (np.arange(grid_t, dtype=np.float32) * second_per_grid * tokens_per_second).astype(\n        np.int64\n    )\n    t_index = np.repeat(time_axis, grid_h * grid_w)\n    h_index = np.tile(np.repeat(np.arange(grid_h, dtype=np.int64), grid_w), grid_t)\n    w_index = np.tile(np.tile(np.arange(grid_w, dtype=np.int64), grid_h), grid_t)\n    stacked = np.stack([t_index, h_index, w_index])\n    return stacked + offset\n\n\ndef _find_token_index(tokens: Sequence[int], token_id: int, start: int) -> int:\n    for idx in range(start, len(tokens)):\n        if tokens[idx] == token_id:\n   
         return idx\n    return len(tokens)\n\n\ndef _next_chunk_offset(chunks: Sequence[np.ndarray]) -> int:\n    if not chunks:\n        return 0\n    return int(chunks[-1].max()) + 1\n\n\ndef _count_vision_items(\n    token_array: np.ndarray,\n    vision_start_token_id: int,\n    image_token_id: int,\n    video_token_id: int,\n) -> Tuple[int, int]:\n    vision_starts = np.where(token_array == vision_start_token_id)[0]\n    valid_starts = vision_starts[vision_starts + 1 < token_array.shape[0]]\n    following_tokens = token_array[valid_starts + 1]\n    image_count = int(np.sum(following_tokens == image_token_id))\n    video_count = int(np.sum(following_tokens == video_token_id))\n    return image_count, video_count\n\n\ndef _next_vision_block(\n    tokens: Sequence[int],\n    start: int,\n    meta: VisionPositionMetadata,\n    has_images: bool,\n    has_videos: bool,\n) -> Tuple[str, int]:\n    sentinel = len(tokens) + 1\n    image_end = _find_token_index(tokens, meta.image_token_id, start) if has_images else sentinel\n    video_end = _find_token_index(tokens, meta.video_token_id, start) if has_videos else sentinel\n    if image_end < video_end:\n        return \"image\", image_end\n    return \"video\", video_end\n\n\ndef _load_grid_for_block(  # pylint: disable=too-many-arguments\n    block_kind: str,\n    image_grid_thw: Optional[np.ndarray],\n    video_grid_thw: Optional[np.ndarray],\n    second_per_grid_ts: Optional[np.ndarray],\n    image_index: int,\n    video_index: int,\n) -> Tuple[int, int, int, float, int, int]:\n    if block_kind == \"image\":\n        if image_grid_thw is None:\n            raise ValueError(\"Image grids are required for sequences with image tokens.\")\n        grid_t, grid_h, grid_w = image_grid_thw[image_index]\n        return int(grid_t), int(grid_h), int(grid_w), 0.0, image_index + 1, video_index\n\n    if video_grid_thw is None:\n        raise ValueError(\"Video grids are required for sequences with video tokens.\")\n    grid_t, 
grid_h, grid_w = video_grid_thw[video_index]\n    second_per_grid = (\n        float(second_per_grid_ts[video_index]) if second_per_grid_ts is not None else 1.0\n    )\n    return int(grid_t), int(grid_h), int(grid_w), second_per_grid, image_index, video_index + 1\n\n\ndef _build_sequence_position_ids(  # pylint: disable=too-many-arguments,too-many-locals\n    input_tokens: Sequence[int],\n    meta: VisionPositionMetadata,\n    image_grid_thw: Optional[np.ndarray],\n    video_grid_thw: Optional[np.ndarray],\n    second_per_grid_ts: Optional[np.ndarray],\n    image_index: int,\n    video_index: int,\n) -> Tuple[np.ndarray, int, int, int]:\n    token_array = np.asarray(input_tokens, dtype=np.int64)\n    image_count, video_count = _count_vision_items(\n        token_array,\n        vision_start_token_id=meta.vision_start_token_id,\n        image_token_id=meta.image_token_id,\n        video_token_id=meta.video_token_id,\n    )\n    if image_count > 0 and image_grid_thw is None:\n        raise ValueError(\"Image grids are required for sequences with image tokens.\")\n    if video_count > 0 and video_grid_thw is None:\n        raise ValueError(\"Video grids are required for sequences with video tokens.\")\n\n    llm_pos_ids_list: List[np.ndarray] = []\n    start = 0\n    remain_images = image_count\n    remain_videos = video_count\n    for _ in range(image_count + video_count):\n        block_kind, block_end = _next_vision_block(\n            tokens=input_tokens,\n            start=start,\n            meta=meta,\n            has_images=remain_images > 0,\n            has_videos=remain_videos > 0,\n        )\n        (\n            grid_t,\n            grid_h,\n            grid_w,\n            second_per_grid,\n            image_index,\n            video_index,\n        ) = _load_grid_for_block(\n            block_kind=block_kind,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            second_per_grid_ts=second_per_grid_ts,\n   
         image_index=image_index,\n            video_index=video_index,\n        )\n        if block_kind == \"image\":\n            remain_images -= 1\n        else:\n            remain_videos -= 1\n\n        llm_grid_h, llm_grid_w = meta.merged_hw(grid_h, grid_w)\n        text_len = block_end - start\n        text_offset = _next_chunk_offset(llm_pos_ids_list)\n        llm_pos_ids_list.append(_text_chunk(text_len, text_offset))\n        grid_offset = text_offset + text_len\n        llm_pos_ids_list.append(\n            _grid_chunk(\n                grid_t=grid_t,\n                grid_h=llm_grid_h,\n                grid_w=llm_grid_w,\n                offset=grid_offset,\n                tokens_per_second=meta.tokens_per_second,\n                second_per_grid=second_per_grid,\n            )\n        )\n        start = block_end + grid_t * llm_grid_h * llm_grid_w\n\n    if start < len(input_tokens):\n        tail_len = len(input_tokens) - start\n        tail_offset = _next_chunk_offset(llm_pos_ids_list)\n        llm_pos_ids_list.append(_text_chunk(tail_len, tail_offset))\n\n    if not llm_pos_ids_list:\n        empty_positions: np.ndarray = np.zeros((3, 0), dtype=np.int64)\n        return empty_positions, 0, image_index, video_index\n    llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)\n    delta = int(llm_positions.max()) + 1 - len(input_tokens)\n    return llm_positions, delta, image_index, video_index\n\n\ndef _text_only_position_ids(\n    input_ids: np.ndarray,\n    attention_mask: Optional[np.ndarray],\n) -> Tuple[np.ndarray, np.ndarray]:\n    batch, seq_len = input_ids.shape\n    if attention_mask is None:\n        base: np.ndarray = np.arange(seq_len, dtype=np.int64).reshape(1, 1, -1)\n        tiled = np.broadcast_to(base, (3, batch, seq_len))\n        return tiled, np.zeros((batch, 1), dtype=np.int64)\n\n    position = attention_mask.cumsum(axis=-1) - 1\n    position = np.where(attention_mask == 0, 1, position)\n    position = 
np.expand_dims(position, axis=0).repeat(3, axis=0)\n    max_pos = position.max(axis=0, keepdims=False).max(axis=-1, keepdims=True)\n    delta = (max_pos + 1 - seq_len).astype(np.int64)\n    return position.astype(np.int64), delta\n\n\ndef get_mrope_position_ids(  # pylint: disable=too-many-arguments,too-many-locals\n    input_ids: np.ndarray,\n    meta: VisionPositionMetadata,\n    attention_mask: Optional[np.ndarray] = None,\n    image_grid_thw: Optional[np.ndarray] = None,\n    video_grid_thw: Optional[np.ndarray] = None,\n    second_per_grid_ts: Optional[np.ndarray] = None,\n) -> Tuple[np.ndarray, np.ndarray]:\n    \"\"\"Generate 3D position IDs and deltas following Hugging Face Qwen2.5-VL.\"\"\"\n\n    input_ids = np.asarray(input_ids, dtype=np.int64)\n    batch, seq_len = input_ids.shape\n    position_ids = np.ones((3, batch, seq_len), dtype=np.int64)\n\n    attention = None\n    if attention_mask is not None:\n        attention_mask = np.asarray(attention_mask, dtype=np.int64)\n        if attention_mask.shape != input_ids.shape:\n            raise ValueError(\n                \"attention_mask shape must match input_ids shape: \"\n                f\"{attention_mask.shape} vs {input_ids.shape}\"\n            )\n        attention = attention_mask.astype(bool)\n\n    image_grid_thw = None if image_grid_thw is None else np.asarray(image_grid_thw, dtype=np.int64)\n    video_grid_thw = None if video_grid_thw is None else np.asarray(video_grid_thw, dtype=np.int64)\n    if second_per_grid_ts is not None:\n        second_per_grid_ts = np.asarray(second_per_grid_ts, dtype=np.float32)\n\n    contains_image_tokens = bool(np.any(input_ids == meta.image_token_id))\n    contains_video_tokens = bool(np.any(input_ids == meta.video_token_id))\n    if contains_image_tokens and image_grid_thw is None:\n        raise ValueError(\"image_grid_thw must be provided when image tokens exist in input_ids.\")\n    if contains_video_tokens and video_grid_thw is None:\n        raise 
ValueError(\"video_grid_thw must be provided when video tokens exist in input_ids.\")\n    if (\n        second_per_grid_ts is not None\n        and video_grid_thw is not None\n        and second_per_grid_ts.shape[0] != video_grid_thw.shape[0]\n    ):\n        raise ValueError(\n            \"second_per_grid_ts length must match number of video grids \"\n            f\"({second_per_grid_ts.shape[0]} vs {video_grid_thw.shape[0]}).\"\n        )\n\n    if not (contains_image_tokens or contains_video_tokens):\n        return _text_only_position_ids(input_ids, attention_mask)\n\n    image_index = 0\n    video_index = 0\n    deltas: List[int] = []\n\n    for batch_idx in range(batch):\n        tokens = input_ids[batch_idx]\n        if attention is not None:\n            tokens = tokens[attention[batch_idx]]\n        token_values = np.asarray(tokens, dtype=np.int64).tolist()\n        input_tokens: List[int] = [int(token) for token in token_values]\n        if not input_tokens:\n            deltas.append(0)\n            continue\n\n        llm_positions, delta, image_index, video_index = _build_sequence_position_ids(\n            input_tokens=input_tokens,\n            meta=meta,\n            image_grid_thw=image_grid_thw,\n            video_grid_thw=video_grid_thw,\n            second_per_grid_ts=second_per_grid_ts,\n            image_index=image_index,\n            video_index=video_index,\n        )\n        if attention is not None:\n            position_ids[:, batch_idx, attention[batch_idx]] = llm_positions\n        else:\n            position_ids[:, batch_idx, :] = llm_positions\n        deltas.append(delta)\n\n    delta_array = np.asarray(deltas, dtype=np.int64).reshape(batch, 1)\n    return position_ids, delta_array\n"
  },
  {
    "path": "python/mlc_llm/op/pipeline_parallel.py",
    "content": "\"\"\"Operators for pipeline parallelism.\"\"\"\n\nfrom typing import List\n\nfrom tvm import relax\nfrom tvm.relax.frontend.nn import Tensor, op\n\n\ndef pipeline_stage_boundary(*tensors: Tensor) -> List[Tensor]:\n    \"\"\"Pipeline parallelism stage boundary mark operator in MLC.\n\n    Parameters\n    ----------\n    tensors : Tensor\n        The tensors to be passed to the next stage.\n\n    Returns\n    -------\n    tensors : List[Tensor]\n        The list of input tensors passed to the next stage.\n    \"\"\"\n    # pylint: disable=protected-access\n    return op.wrap_nested(\n        relax.call_pure_packed(\n            \"mlc.pipeline_parallel_stage_boundary\",\n            *[tensor._expr for tensor in tensors],\n            sinfo_args=(\n                tensors[0]._expr.struct_info\n                if len(tensors) == 1\n                else relax.TupleStructInfo([tensor._expr.struct_info for tensor in tensors])\n            ),\n        ),\n        name=\"pipeline_stage_boundary\",\n    )\n    # pylint: enable=protected-access\n"
  },
  {
    "path": "python/mlc_llm/op/top_p_pivot.py",
    "content": "\"\"\"Operators for choosing the pivot used to cut off the top-p percentile.\"\"\"\n\nimport tvm\nfrom tvm.script import tir as T\n\nfrom mlc_llm.support.max_thread_check import get_max_num_threads_per_block\n\n# mypy: disable-error-code=\"attr-defined,valid-type,name-defined\"\n# pylint: disable=too-many-locals,invalid-name,too-many-arguments,unnecessary-lambda\n# pylint: disable=too-many-statements,line-too-long,too-many-nested-blocks,too-many-branches\n\n\ndef top_p_pivot(pN, target: tvm.target.Target):\n    \"\"\"Top-p pivot function. This function finds the pivot used to cut off the top-p percentile.\n\n    A valid pivot should satisfy the following conditions:\n    - lsum >= top_p\n    - top_p > lsum - cmin * lmin\n    where lsum is the sum of the elements that are greater than or equal to the pivot,\n    lmin is the minimum element that is greater than or equal to the pivot,\n    and cmin is the count of elements that are equal to lmin.\n\n    The parameters below are the buffer arguments of the returned PrimFunc.\n\n    Parameters\n    ----------\n    prob:\n        The probability vector\n\n    top_p_arr:\n        The top-p threshold\n\n    init_pivots:\n        The initial pivot candidates\n\n    final_pivot:\n        The final pivot used to cut off the top-p percentile\n\n    final_lsum:\n        The final sum of the values after top-p filtering.\n    \"\"\"\n    TX = 1024\n    K = 32\n    eps_LR = 1e-7\n\n    max_num_threads_per_block = get_max_num_threads_per_block(target)\n    TX = min(TX, max_num_threads_per_block)\n\n    def _var(dtype=\"int32\"):\n        return T.sblock_alloc_buffer((1,), dtype, scope=\"local\")\n\n    def valid(lsum, lmin, cmin, top_p):\n        return tvm.tir.all(lsum >= top_p, top_p > lsum - cmin * lmin)\n\n    # fmt: off\n    @T.prim_func(private=True)\n    def _func(\n        var_prob: T.handle,\n        var_top_p_arr: T.handle,\n        var_init_pivots: T.handle,\n        var_final_pivot: T.handle,\n        var_final_lsum: T.handle,\n    ):\n        T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": True})\n        B = 
T.int32(is_size_var=True)\n        N = T.int32(is_size_var=True)\n        prob = T.match_buffer(var_prob, (B, N,), \"float32\")\n        top_p_arr = T.match_buffer(var_top_p_arr, (B,), dtype=\"float32\")\n        init_pivots = T.match_buffer(var_init_pivots, (B, pN), \"float32\")\n        final_pivot = T.match_buffer(var_final_pivot, (B,), \"float32\")\n        final_lsum = T.match_buffer(var_final_lsum, (B,), \"float32\")\n\n        with T.sblock(\"kernel\"):\n            pivot = T.sblock_alloc_buffer((pN,), \"float32\", scope=\"local\")\n            top_p = _var(\"float32\")\n\n            L = T.sblock_alloc_buffer((1,), \"float32\", scope=\"shared\")\n            R = T.sblock_alloc_buffer((1,), \"float32\", scope=\"shared\")\n            L_local = _var(\"float32\")\n            R_local = _var(\"float32\")\n\n            q = _var(\"float32\")\n            lsum = T.sblock_alloc_buffer((pN,), \"float32\", scope=\"local\")\n            lmin_broadcast = T.sblock_alloc_buffer((1), \"float32\", scope=\"shared\")\n            lmin_broadcast_local = _var(\"float32\")\n            lmin = T.sblock_alloc_buffer((pN,), \"float32\", scope=\"local\")\n            cmin = T.sblock_alloc_buffer((pN,), \"int32\", scope=\"local\")\n            total_sum = _var(\"float32\")\n\n            it = _var(\"int32\")\n            es_local = _var(\"bool\")\n            es = T.sblock_alloc_buffer((1,), \"bool\", scope=\"shared\")\n            find_pivot_local = _var(\"bool\")\n            find_pivot = T.sblock_alloc_buffer((1,), \"bool\", scope=\"shared\")\n\n            total_sum_reduce = _var(\"float32\")\n            lsum_reduce = _var(\"float32\")\n            lmin_reduce = _var(\"float32\")\n            cmin_reduce = _var(\"int32\")\n\n            for _bx in T.thread_binding(0, B, thread=\"blockIdx.x\"):\n                for _tx in T.thread_binding(0, TX, thread=\"threadIdx.x\"):\n                    with T.sblock(\"CTA\"):\n                        b, tx = T.axis.remap(\"SS\", [_bx, 
_tx])\n\n                        top_p[0] = top_p_arr[b]\n\n                        if tx == 0:\n                            # leader thread initializes L, R\n                            L[0] = 1.0 - top_p[0]\n                            R[0] = eps_LR\n                            find_pivot[0] = False\n                        T.tvm_storage_sync(\"shared\")\n\n                        L_local[0] = L[0]\n                        R_local[0] = R[0]\n                        for i in T.unroll(0, pN):\n                            # pivots are in descending order\n                            pivot[i] = init_pivots[b, i]\n                        find_pivot_local[0] = False\n                        if L_local[0] - R_local[0] <= eps_LR:\n                            # When the initial value is too small, set the result directly.\n                            if tx == 0:\n                                final_lsum[b] = 1.0\n                                final_pivot[b] = 0.0\n                            find_pivot_local[0] = True\n\n                        while T.tvm_thread_invariant(\n                            L_local[0] - R_local[0] > eps_LR\n                            and T.Not(find_pivot_local[0])\n                        ):\n                            # sync before each iteration\n                            T.tvm_storage_sync(\"shared\")\n\n                            ### get lsum, lmin, total_sum\n                            for pidx in T.unroll(0, pN):\n                                lsum[pidx] = 0.0\n                                lmin[pidx] = T.max_value(\"float32\")\n                                cmin[pidx] = 0\n                            total_sum[0] = 0.0\n                            it[0] = 0\n                            es_local[0] = False\n                            while it[0] < T.ceildiv(N, TX) and T.Not(es_local[0]):\n                                idx = T.meta_var(it[0] * TX + tx)\n                                q[0] = T.if_then_else(idx < N, 
prob[b, idx], 0.0)\n                                total_sum[0] += q[0]\n                                for pidx in T.unroll(0, pN):\n                                    if q[0] >= pivot[pidx]:\n                                        lsum[pidx] += q[0]\n                                        if lmin[pidx] > q[0]:\n                                            lmin[pidx] = q[0]\n                                            cmin[pidx] = 1\n                                        elif lmin[pidx] == q[0]:\n                                            cmin[pidx] += 1\n                                it[0] += 1\n\n                                # early stop every K iterations\n                                if it[0] % K == 0:\n                                    # reduce total_sum over tx\n                                    # T.tvm_storage_sync(\"shared\")\n                                    with T.sblock(\"block_cross_thread\"):\n                                        T.reads(total_sum[0])\n                                        T.writes(total_sum_reduce[0])\n                                        T.attr(\n                                            T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),\n                                            \"reduce_scope\",\n                                            T.reinterpret(\"handle\", T.uint64(0)),\n                                        )\n                                        T.tvm_thread_allreduce(T.uint32(1), total_sum[0], True, total_sum_reduce[0], tx, dtype=\"handle\")\n                                    # T.tvm_storage_sync(\"shared\")\n\n                                    if tx == 0:\n                                        # leader thread checks if we can stop early\n                                        es[0] = 1 - total_sum_reduce[0] < pivot[pN - 1]\n                                    T.tvm_storage_sync(\"shared\")\n                                    es_local[0] = es[0]\n\n                    
        T.tvm_storage_sync(\"shared\")\n\n                            # reduce lsum, lmin, cmin, over tx\n                            for pidx in T.serial(0, pN):\n                                # reduce lsum over tx for pivot[j]\n                                with T.sblock(\"block_cross_thread\"):\n                                    T.reads(lsum[pidx])\n                                    T.writes(lsum_reduce[0])\n                                    T.attr(\n                                        T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),\n                                        \"reduce_scope\",\n                                        T.reinterpret(\"handle\", T.uint64(0)),\n                                    )\n                                    T.tvm_thread_allreduce(T.uint32(1), lsum[pidx], True, lsum_reduce[0], tx, dtype=\"handle\")\n\n                                # reduce lmin over tx for pivot[j]\n                                with T.sblock(\"block_cross_thread\"):\n                                    T.reads(lmin[pidx])\n                                    T.writes(lmin_reduce[0])\n                                    T.attr(\n                                        T.comm_reducer(lambda x0, y0: T.min(x0, y0), [T.float32(0)]),\n                                        \"reduce_scope\",\n                                        T.reinterpret(\"handle\", T.uint64(0)),\n                                    )\n                                    T.tvm_thread_allreduce(T.uint32(1), lmin[pidx], True, lmin_reduce[0], tx, dtype=\"handle\")\n\n                                if tx == 0:\n                                    # broadcast lmin to all threads\n                                    lmin_broadcast[0] = lmin_reduce[0]\n                                T.tvm_storage_sync(\"shared\")\n                                lmin_broadcast_local[0] = lmin_broadcast[0]\n                                if lmin[pidx] > lmin_broadcast_local[0]:\n       
                             cmin[pidx] = 0\n                                if tx == 0:\n                                    # only the leader thread updates lsum, lmin\n                                    lsum[pidx] = lsum_reduce[0]\n                                    lmin[pidx] = lmin_reduce[0]\n\n                                # reduce cmin over tx for pivot[j]\n                                with T.sblock(\"block_cross_thread\"):\n                                    T.reads(cmin[pidx])\n                                    T.writes(cmin_reduce[0])\n                                    T.attr(\n                                        T.comm_reducer(lambda x0, y0: x0 + y0, [T.int32(0)]),\n                                        \"reduce_scope\",\n                                        T.reinterpret(\"handle\", T.uint64(0)),\n                                    )\n                                    T.tvm_thread_allreduce(T.uint32(1), cmin[pidx], True, cmin_reduce[0], tx, dtype=\"handle\")\n\n                                if tx == 0:\n                                    # only the leader thread updates cmin\n                                    cmin[pidx] = cmin_reduce[0]\n\n                            T.tvm_storage_sync(\"shared\")\n\n                            if tx == 0:\n                                # leader thread checks if we have found the pivot, or updates L, R\n                                it[0] = 0\n                                while it[0] < pN and T.Not(find_pivot_local[0]):\n                                    pidx = T.meta_var(it[0])\n                                    if valid(lsum[pidx], lmin[pidx], cmin[pidx], top_p[0]):\n                                        find_pivot[0] = True\n                                        find_pivot_local[0] = True\n                                        # write back the pivot and lsum\n                                        final_pivot[b] = pivot[pidx]\n                                        
final_lsum[b] = lsum[pidx]\n                                    elif lsum[pidx] - lmin[pidx] * cmin[pidx] >= top_p[0]:\n                                        R[0] = pivot[pidx]\n                                        final_lsum[b] = lsum[pidx]\n                                    elif lsum[pidx] < top_p[0]:\n                                        L[0] = pivot[pidx]\n                                    it[0] += 1\n\n                            T.tvm_storage_sync(\"shared\")\n\n                            L_local[0] = L[0]\n                            R_local[0] = R[0]\n                            find_pivot_local[0] = find_pivot[0]\n                            # new pivots for next iteration\n                            # uniform spacing between L and R\n                            for pidx in T.unroll(0, pN):\n                                pivot[pidx] = L[0] - (pidx + 1) * (L_local[0] - R_local[0]) / (pN + 1)\n\n                        if tx == 0:\n                            # leader thread writes back the pivot\n                            if T.Not(find_pivot_local[0]):\n                                final_pivot[b] = R_local[0]\n                                if R_local[0] == eps_LR:\n                                    final_lsum[b] = lsum[pN - 1]\n    # fmt: on\n\n    return _func\n\n\ndef top_p_renorm(target: tvm.target.Target = None):\n    \"\"\"Top-p renormalization function. 
This function renormalizes the probability vector.\n\n    Given the pivot, the probability vector is renormalized as follows:\n    - if prob >= pivot, renorm_prob = prob / lsum\n    - otherwise, renorm_prob = 0\n\n    Parameters\n    ----------\n    prob:\n        The probability vector\n\n    final_pivot:\n        The final pivot to cut-off top-p percentile\n\n    final_lsum:\n        The sum of elements that are larger or equal to the pivot\n\n    renorm_prob:\n        The renormalized probability vector\n    \"\"\"\n    TX = 1024\n    CTA_COUNT = 512\n\n    if target:\n        max_num_threads_per_block = get_max_num_threads_per_block(target)\n        TX = min(TX, max_num_threads_per_block)\n\n    def _var(dtype=\"int32\"):\n        return T.sblock_alloc_buffer((1,), dtype, scope=\"local\")\n\n    # fmt: off\n    @T.prim_func(private=True)\n    def _func(\n        var_prob: T.handle,\n        var_final_pivot: T.handle,\n        var_final_lsum: T.handle,\n        var_renorm_prob: T.handle,\n    ):\n        T.func_attr({\"tir.is_scheduled\": 1, \"tir.noalias\": True})\n        B = T.int32(is_size_var=True)\n        N = T.int32(is_size_var=True)\n        prob = T.match_buffer(var_prob, (B, N,), \"float32\")\n        final_pivot = T.match_buffer(var_final_pivot, (B,), \"float32\")\n        final_lsum = T.match_buffer(var_final_lsum, (B,), \"float32\")\n        renorm_prob = T.match_buffer(var_renorm_prob, (B, N,), \"float32\")\n\n        with T.sblock(\"kernel\"):\n            pivot = _var(\"float32\")\n            lsum = _var(\"float32\")\n            BX = T.meta_var(T.ceildiv(CTA_COUNT, B))\n\n            for _by in T.thread_binding(0, B, thread=\"blockIdx.y\"):\n                for _bx in T.thread_binding(0, BX, thread=\"blockIdx.x\"):\n                    for _tx in T.thread_binding(0, TX, thread=\"threadIdx.x\"):\n                        with T.sblock(\"CTA\"):\n                            by, bx, tx = T.axis.remap(\"SSS\", [_by, _bx, _tx])\n\n                   
         pivot[0] = final_pivot[by]\n                            lsum[0] = final_lsum[by]\n\n                            for i in T.serial(T.ceildiv(N, BX * TX)):\n                                idx = T.meta_var(i * BX * TX + bx * TX + tx)\n                                if idx < N:\n                                    renorm_prob[by, idx] = T.if_then_else(prob[by, idx] >= pivot[0], prob[by, idx] / lsum[0], 0.0)\n    # fmt: on\n\n    return _func\n"
  },
  {
    "path": "python/mlc_llm/op/triton.py",
    "content": "\"\"\"Operators enabled by external modules.\"\"\"\n\n# pylint: disable=invalid-name\n\nfrom typing import List, Literal, Tuple\n\nimport tvm\nfrom tvm.relax.frontend import nn\nfrom tvm.script import ir as I\nfrom tvm.script import tir as T\n\ntry:\n    import triton\n    import triton.language as tl\nexcept ImportError:\n    triton = None\n    tl = None\n\n\n# We use a wrapper function to avoid type annotation issue of \"tl.constexpr\" when\n# triton is not installed.\ndef _get_triton_w8a8_block_fp8_gemm():\n    # Triton kernel adapted from SGLang project\n    # https://github.com/sgl-project/sglang/blob/v0.4.4/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  # pylint: disable=line-too-long\n    def _triton_w8a8_block_fp8_gemm(  # pylint: disable=too-many-arguments,too-many-locals\n        # Pointers to inputs and output\n        A,\n        B,\n        C,\n        As,\n        Bs,\n        # Shape for matmul\n        M,\n        N: tl.constexpr,\n        K: tl.constexpr,\n        # Stride for inputs and output\n        stride_am: tl.constexpr,\n        stride_ak: tl.constexpr,\n        stride_bk: tl.constexpr,\n        stride_bn: tl.constexpr,\n        stride_cm: tl.constexpr,\n        stride_cn: tl.constexpr,\n        stride_As_m: tl.constexpr,\n        stride_As_k: tl.constexpr,\n        stride_Bs_k: tl.constexpr,\n        stride_Bs_n: tl.constexpr,\n        # Block size for block-wise quantization\n        group_n: tl.constexpr,\n        group_k: tl.constexpr,\n        # Meta-parameters\n        BLOCK_SIZE_M: tl.constexpr,\n        BLOCK_SIZE_N: tl.constexpr,\n        BLOCK_SIZE_K: tl.constexpr,\n        GROUP_SIZE_M: tl.constexpr,\n    ):\n        \"\"\"Triton-accelerated function used to perform linear operations (dot\n        product) on input tensors `A` and `B` with block-wise quantization,\n        and store the result in output tensor `C`.\n        \"\"\"\n\n        pid = tl.program_id(axis=0).to(tl.int64)\n        num_pid_m 
= tl.cdiv(M, BLOCK_SIZE_M)\n        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n        num_pid_in_group = GROUP_SIZE_M * num_pid_n\n        group_id = pid // num_pid_in_group\n        first_pid_m = group_id * GROUP_SIZE_M\n        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n        pid_m = first_pid_m + (pid % group_size_m)\n        pid_n = (pid % num_pid_in_group) // group_size_m\n\n        offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n        offs_k = tl.arange(0, BLOCK_SIZE_K)\n        a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n        b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n        As_ptrs = As + offs_am * stride_As_m\n        offs_bsn = offs_bn // group_n\n        Bs_ptrs = Bs + offs_bsn * stride_Bs_n\n\n        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n\n            k_start = k * BLOCK_SIZE_K\n            offs_ks = k_start // group_k\n            a_s = tl.load(As_ptrs + offs_ks * stride_As_k)\n            b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)\n\n            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]\n            a_ptrs += BLOCK_SIZE_K * stride_ak\n            b_ptrs += BLOCK_SIZE_K * stride_bk\n\n        if C.dtype.element_ty == tl.bfloat16:\n            c = accumulator.to(tl.bfloat16)\n        elif C.dtype.element_ty == tl.float16:\n            c = accumulator.to(tl.float16)\n        else:\n            c = accumulator.to(tl.float32)\n\n        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        c_ptrs = C + stride_cm * 
offs_cm[:, None] + stride_cn * offs_cn[None, :]\n        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n        tl.store(c_ptrs, c, mask=c_mask)\n\n    return _triton_w8a8_block_fp8_gemm\n\n\n# We use a wrapper function to avoid type annotation issue of \"tl.constexpr\" when\n# triton is not installed.\ndef _get_triton_w8a8_block_fp8_group_gemm():\n    # Triton kernel adapted from SGLang project\n    # https://github.com/sgl-project/sglang/blob/v0.4.4/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  # pylint: disable=line-too-long\n    def _triton_w8a8_block_fp8_group_gemm(  # pylint: disable=too-many-arguments,too-many-locals\n        # Pointers to matrices\n        a_ptr,\n        b_ptr,\n        c_ptr,\n        a_scale_ptr,\n        b_scale_ptr,\n        expert_ids_ptr,\n        indptr_ptr,\n        # Matrix dimensions\n        EM,\n        N: tl.constexpr,\n        K: tl.constexpr,\n        num_experts: tl.constexpr,\n        # The stride variables represent how much to increase the ptr by when\n        # moving by 1 element in a particular dimension. E.g. 
`stride_am` is\n        # how much to increase `a_ptr` by to get the element one row down\n        # (A has M rows).\n        stride_am: tl.constexpr,\n        stride_ak: tl.constexpr,\n        stride_be: tl.constexpr,\n        stride_bk: tl.constexpr,\n        stride_bn: tl.constexpr,\n        stride_cm: tl.constexpr,\n        stride_cn: tl.constexpr,\n        stride_asm: tl.constexpr,\n        stride_ask: tl.constexpr,\n        stride_bse: tl.constexpr,\n        stride_bsk: tl.constexpr,\n        stride_bsn: tl.constexpr,\n        # Block size for block-wise quantization\n        group_n: tl.constexpr,\n        group_k: tl.constexpr,\n        # Meta-parameters\n        BLOCK_SIZE_M: tl.constexpr,\n        BLOCK_SIZE_N: tl.constexpr,\n        BLOCK_SIZE_K: tl.constexpr,\n        GROUP_SIZE_M: tl.constexpr,\n        even_Ks: tl.constexpr,\n    ):\n        \"\"\"\n        Implements the fused computation for a Mixture of Experts (MOE) using\n        token and expert matrices.\n\n        Key Parameters:\n        - A: The input tensor representing tokens with shape (*, K), where '*' can\n            be any shape representing batches and K is the feature dimension of\n            each token.\n        - B: The stacked MOE weight tensor with shape (E, N, K), where E is\n            the number of experts, K is the input feature dimension, and N is\n            the output feature dimension.\n        - C: The output cache tensor with shape (*, N), where '*' means the\n            same shape as the input tensor A, and N is the output feature dimension.\n        - expert_ids: A tensor containing the indices of the expert for each\n            block. 
It determines which expert matrix from B should be used for\n            each block in A.\n        This kernel performs the multiplication of a token by its corresponding\n        expert matrix as determined by `expert_ids`.\n        \"\"\"\n        # -----------------------------------------------------------\n        # Map program ids `pid` to the block of C it should compute.\n        # This is done in a grouped ordering to promote L2 data reuse.\n        pid = tl.program_id(axis=0).to(tl.int64)\n        num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_experts\n        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n        num_pid_in_group = GROUP_SIZE_M * num_pid_n\n        group_id = pid // num_pid_in_group\n        first_pid_m = group_id * GROUP_SIZE_M\n        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n        pid_n = (pid % num_pid_in_group) // group_size_m\n\n        # ----------------------------------------------------------\n        # Create pointers for the first blocks of A and B.\n        # We will advance this pointer as we move in the K direction\n        # and accumulate\n        # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers\n        # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers\n        expert_id = tl.load(expert_ids_ptr + pid_m).to(tl.int64)\n        if expert_id == -1:\n            return\n\n        token_begin = tl.load(indptr_ptr + expert_id)\n        token_end = tl.load(indptr_ptr + expert_id + 1)\n        start_pid_m = tl.cdiv(token_begin, BLOCK_SIZE_M) + expert_id\n        offs_token_id = (\n            token_begin + (pid_m - start_pid_m) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n        )\n        token_mask = offs_token_id < token_end\n\n        offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n        offs_k = tl.arange(0, BLOCK_SIZE_K)\n        a_ptrs = a_ptr + offs_token_id[:, None] * stride_am + 
offs_k[None, :] * stride_ak\n\n        b_ptrs = (\n            b_ptr\n            + expert_id * stride_be\n            + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n        )\n\n        a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm\n        offs_bsn = offs_bn // group_n\n        b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_bsn * stride_bsn\n\n        # -----------------------------------------------------------\n        # Iterate to compute a block of the C matrix.\n        # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n        # of fp32 values for higher accuracy.\n        # `accumulator` is cast to the output dtype of C after the loop.\n        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n            # Load the next block of A and B, generate a mask by checking the\n            # K dimension.\n            if even_Ks:\n                a = tl.load(\n                    a_ptrs,\n                    mask=token_mask[:, None],\n                    other=0.0,\n                )\n                b = tl.load(b_ptrs)\n            else:\n                a = tl.load(\n                    a_ptrs,\n                    mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),\n                    other=0.0,\n                )\n                b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n\n            # We accumulate along the K dimension.\n            k_start = k * BLOCK_SIZE_K\n            offs_ks = k_start // group_k\n            a_scale = tl.load(a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0)\n            b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)\n\n            accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]\n            # Advance the ptrs to the next K block.\n            a_ptrs += BLOCK_SIZE_K * stride_ak\n            b_ptrs += BLOCK_SIZE_K * 
stride_bk\n\n        if c_ptr.dtype.element_ty == tl.bfloat16:\n            accumulator = accumulator.to(tl.bfloat16)\n        elif c_ptr.dtype.element_ty == tl.float16:\n            accumulator = accumulator.to(tl.float16)\n        else:\n            accumulator = accumulator.to(tl.float32)\n\n        # -----------------------------------------------------------\n        # Write back the block of the output\n        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n        c_ptrs = c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]\n        c_mask = token_mask[:, None] & (offs_cn[None, :] < N)\n        tl.store(c_ptrs, accumulator, mask=c_mask)\n\n    return _triton_w8a8_block_fp8_group_gemm\n\n\ndef get_tir_w8a8_block_fp8_matmul(  # pylint: disable=too-many-arguments,too-many-locals\n    N: int,\n    K: int,\n    block_n: int,\n    block_k: int,\n    in_dtype: Literal[\"float8_e4m3fn\"],\n    out_dtype: Literal[\"float16\", \"bfloat16\"],\n    BLOCK_SIZE_M: int,\n    BLOCK_SIZE_N: int,\n    BLOCK_SIZE_K: int,\n    GROUP_SIZE_M: int,\n    num_warps: int,\n    num_stages: int,\n    extern_mods: List[tvm.runtime.Module],\n):\n    \"\"\"Get the TIR function for the w8a8_block_fp8_matmul kernel.\"\"\"\n    # NOTE: adding the type annotation of \" -> Tuple[Optional[tvm.tir.PrimFunc], str]\"\n    # will cause the failure of the type resolution in mypy.\n    if triton is None:\n        raise RuntimeError(\"Triton is not installed. 
Please install it with `pip install triton`.\")\n\n    name_suffix = f\"_N{N}_K{K}_block_n{block_n}_block_k{block_k}_in{in_dtype}_out{out_dtype}\"\n    kernel_name = f\"triton_w8a8_block_fp8_gemm{name_suffix}\"\n    tir_name = f\"tir_w8a8_block_fp8_matmul{name_suffix}\"\n    for ext_mod in extern_mods:\n        if ext_mod.implements_function(kernel_name):\n            return [None, tir_name]\n\n    triton_kernel = _get_triton_w8a8_block_fp8_gemm()\n    triton_kernel.__name__ = kernel_name\n\n    @I.ir_module\n    class BlockFP8Matmul:  # pylint: disable=missing-class-docstring,too-few-public-methods\n        @T.prim_func(private=True)\n        def tir_w8a8_block_fp8_matmul(  # pylint: disable=missing-function-docstring\n            var_A: T.handle,\n            var_B: T.handle,\n            var_As: T.handle,\n            var_Bs: T.handle,\n            var_C: T.handle,\n        ):\n            T.func_attr({\"op_pattern\": 8, \"tir.is_scheduled\": 1})\n            M = T.SizeVar(\"M\", \"int32\")\n            A = T.match_buffer(var_A, (M, K), dtype=in_dtype)\n            B = T.match_buffer(var_B, (N, K), dtype=in_dtype)\n            As = T.match_buffer(var_As, (M, (K + block_k - 1) // block_k), \"float32\")\n            Bs = T.match_buffer(\n                var_Bs,\n                ((N + block_n - 1) // block_n, (K + block_k - 1) // block_k),\n                \"float32\",\n            )\n            C = T.match_buffer(var_C, (M, N), dtype=out_dtype)\n            with T.sblock(\"root\"):\n                T.reads(\n                    A[0:M, 0:K],\n                    B[0:N, 0:K],\n                    As[0:M, 0 : (K + block_k - 1) // block_k],\n                    Bs[\n                        0 : (N + block_n - 1) // block_n,\n                        0 : (K + block_k - 1) // block_k,\n                    ],\n                )\n                T.writes(C[0:M, 0:N])\n                T.call_kernel(\n                    triton.jit(triton_kernel),\n                    
(T.ceildiv(M, BLOCK_SIZE_M) * T.ceildiv(N, BLOCK_SIZE_N),),\n                    A.data,\n                    B.data,\n                    C.data,\n                    As.data,\n                    Bs.data,\n                    M,\n                    N,\n                    K,\n                    K,  # stride_am\n                    1,  # stride_ak\n                    1,  # stride_bk\n                    K,  # stride_bn\n                    N,  # stride_cm\n                    1,  # stride_cn\n                    (K + block_k - 1) // block_k,  # stride_As_m\n                    1,  # stride_As_k\n                    1,  # stride_Bs_k\n                    (K + block_k - 1) // block_k,  # stride_Bs_n\n                    block_n,\n                    block_k,\n                    BLOCK_SIZE_M,\n                    BLOCK_SIZE_N,\n                    BLOCK_SIZE_K,\n                    GROUP_SIZE_M,\n                    num_warps=num_warps,\n                    num_stages=num_stages,\n                )\n\n    new_ext_mods = BlockFP8Matmul.attrs[\"external_mods\"]  # type: ignore  # pylint: disable=no-member\n    assert len(new_ext_mods) == 1\n    extern_mods.append(new_ext_mods[0])\n    return BlockFP8Matmul[\"tir_w8a8_block_fp8_matmul\"], tir_name  # type: ignore\n\n\ndef get_tir_w8a8_block_fp8_group_matmul(  # pylint: disable=too-many-arguments,too-many-locals\n    N: int,\n    K: int,\n    num_experts: int,\n    block_n: int,\n    block_k: int,\n    in_dtype: Literal[\"float8_e4m3fn\"],\n    out_dtype: Literal[\"float16\", \"bfloat16\"],\n    BLOCK_SIZE_M: int,\n    BLOCK_SIZE_N: int,\n    BLOCK_SIZE_K: int,\n    GROUP_SIZE_M: int,\n    num_warps: int,\n    num_stages: int,\n    extern_mods: List[tvm.runtime.Module],\n):\n    \"\"\"Get the TIR function for the w8a8_block_fp8_group_gemm kernel.\"\"\"\n    if triton is None:\n        raise RuntimeError(\"Triton is not installed. 
Please install it with `pip install triton`.\")\n\n    name_suffix = (\n        f\"_N{N}_K{K}_num_experts{num_experts}_block_n{block_n}\"\n        f\"_block_k{block_k}_in{in_dtype}_out{out_dtype}\"\n    )\n    kernel_name = f\"triton_w8a8_block_fp8_group_gemm{name_suffix}\"\n    tir_name = f\"tir_w8a8_block_fp8_group_gemm{name_suffix}\"\n    for ext_mod in extern_mods:\n        if ext_mod.implements_function(kernel_name):\n            return [None, tir_name]\n\n    triton_kernel = _get_triton_w8a8_block_fp8_group_gemm()\n    triton_kernel.__name__ = kernel_name\n\n    @I.ir_module\n    class BlockFP8GroupMatmul:  # pylint: disable=missing-class-docstring,too-few-public-methods\n        @T.prim_func(private=True)\n        def tir_w8a8_block_fp8_group_gemm(  # pylint: disable=missing-function-docstring,too-many-arguments\n            var_A: T.handle,\n            var_B: T.handle,\n            var_As: T.handle,\n            var_Bs: T.handle,\n            var_expert_ids: T.handle,\n            var_indptr: T.handle,\n            var_C: T.handle,\n        ):\n            T.func_attr({\"op_pattern\": 8, \"tir.is_scheduled\": 1})\n            EM = T.SizeVar(\"EM\", \"int32\")\n            A = T.match_buffer(var_A, (EM, K), dtype=in_dtype)\n            B = T.match_buffer(var_B, (num_experts, N, K), dtype=in_dtype)\n            As = T.match_buffer(var_As, (EM, (K + block_k - 1) // block_k), \"float32\")\n            Bs = T.match_buffer(\n                var_Bs,\n                (\n                    num_experts,\n                    (N + block_n - 1) // block_n,\n                    (K + block_k - 1) // block_k,\n                ),\n                \"float32\",\n            )\n            expert_ids = T.match_buffer(\n                var_expert_ids,\n                ((EM + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_experts,),\n                \"int32\",\n            )\n            indptr = T.match_buffer(var_indptr, (num_experts + 1,), \"int32\")\n            C = 
T.match_buffer(var_C, (EM, N), dtype=out_dtype)\n\n            with T.sblock(\"root\"):\n                T.reads(\n                    A[0:EM, 0:K],\n                    B[0:num_experts, 0:N, 0:K],\n                    As[0:EM, 0 : (K + block_k - 1) // block_k],\n                    Bs[\n                        0:num_experts,\n                        0 : (N + block_n - 1) // block_n,\n                        0 : (K + block_k - 1) // block_k,\n                    ],\n                    expert_ids[0 : (EM + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_experts],\n                    indptr[0 : num_experts + 1],\n                )\n                T.writes(C[0:EM, 0:N])\n                T.call_kernel(\n                    triton.jit(triton_kernel),\n                    ((T.ceildiv(EM, BLOCK_SIZE_M) + num_experts) * T.ceildiv(N, BLOCK_SIZE_N),),\n                    A.data,\n                    B.data,\n                    C.data,\n                    As.data,\n                    Bs.data,\n                    expert_ids.data,\n                    indptr.data,\n                    EM,\n                    N,\n                    K,\n                    num_experts,\n                    K,  # stride_am\n                    1,  # stride_ak\n                    N * K,  # stride_be\n                    1,  # stride_bk\n                    K,  # stride_bn\n                    N,  # stride_cm\n                    1,  # stride_cn\n                    (K + block_k - 1) // block_k,  # stride_asm\n                    1,  # stride_ask\n                    ((N + block_n - 1) // block_n) * ((K + block_k - 1) // block_k),  # stride_bse\n                    1,  # stride_bsk\n                    (K + block_k - 1) // block_k,  # stride_Bs_n\n                    block_n,\n                    block_k,\n                    BLOCK_SIZE_M,\n                    BLOCK_SIZE_N,\n                    BLOCK_SIZE_K,\n                    GROUP_SIZE_M,\n                    K % BLOCK_SIZE_K == 0,\n          
          num_warps=num_warps,\n                    num_stages=num_stages,\n                )\n\n    new_ext_mods = BlockFP8GroupMatmul.attrs[\"external_mods\"]  # type: ignore  # pylint: disable=no-member\n    assert len(new_ext_mods) == 1\n    extern_mods.append(new_ext_mods[0])\n    return BlockFP8GroupMatmul[\"tir_w8a8_block_fp8_group_gemm\"], tir_name  # type: ignore\n\n\ndef _compute_expert_id_per_block(\n    indptr: nn.Tensor,\n    num_experts: int,\n    M: nn.IntExpr,\n    BLOCK_SIZE_M: int,\n) -> nn.Tensor:\n    \"\"\"Compute the expert id for each threadblock (CTA).\n    We assign an expert id to each threadblock, and the threadblock\n    computes the gemm with the weight of that expert. Padding blocks\n    between experts are assigned -1 and skipped by the kernel.\n\n    Parameters\n    ----------\n    indptr : nn.Tensor\n        The indptr tensor of group gemm, with shape of [num_experts + 1,].\n\n    num_experts : int\n        The total number of experts.\n\n    M : nn.IntExpr\n        The number of tokens.\n\n    BLOCK_SIZE_M : int\n        The block size of the threadblock along the token (M) dimension.\n\n    Returns\n    -------\n    expert_ids : nn.Tensor\n        The expert id for each threadblock, with shape of\n        [(M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_experts,].\n    \"\"\"\n\n    @T.prim_func\n    def tir_compute_expert_id_per_block(\n        var_indptr: T.handle,\n        var_expert_ids: T.handle,\n        M: T.int64,\n    ):\n        T.func_attr({\"op_pattern\": 8, \"tir.is_scheduled\": 1})\n        indptr = T.match_buffer(var_indptr, (num_experts + 1,), \"int32\")\n        expert_ids = T.match_buffer(\n            var_expert_ids,\n            ((M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_experts,),\n            \"int32\",\n        )\n        with T.sblock(\"root\"):\n            for eid in T.thread_binding(0, num_experts, thread=\"threadIdx.x\"):\n                start_block_id: T.int32 = (indptr[eid] + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + eid\n                num_blocks: T.int32 = (\n                 
   indptr[eid + 1] - indptr[eid] + BLOCK_SIZE_M - 1\n                ) // BLOCK_SIZE_M\n                start_block_id_next_expert: T.int32 = (\n                    (indptr[eid + 1] + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + eid + 1\n                )\n                for block_id in T.serial(num_blocks):\n                    expert_ids[start_block_id + block_id] = eid\n                for block_id in T.serial(\n                    start_block_id_next_expert - (start_block_id + num_blocks)\n                ):\n                    expert_ids[start_block_id + num_blocks + block_id] = -1\n\n    assert num_experts <= 1024\n    return nn.tensor_ir_op(\n        tir_compute_expert_id_per_block,\n        \"tir_compute_expert_id_per_block\",\n        args=[indptr, M],\n        out=nn.Tensor.placeholder(\n            ((M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M + num_experts,), dtype=\"int32\"\n        ),\n    )\n\n\ndef fp8_groupwise_scaled_gemm(  # pylint: disable=too-many-arguments,too-many-locals\n    x: nn.Tensor,\n    x_scale: nn.Tensor,\n    weight: nn.Tensor,\n    weight_scale: nn.Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n) -> nn.Tensor:\n    \"\"\"Triton block-scale fp8 gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [m, k].\n\n    x_scale : nn.Tensor\n        The scale tensor, with shape of [m, k // block_size].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [n, k].\n\n    weight_scale : nn.Tensor\n        The scale tensor, with shape of [n // block_size, k // block_size].\n\n    block_size : Tuple[int, int]\n        The block size.\n\n    out_dtype : str\n        The data type of the output tensor.\n\n    Returns\n    -------\n    out : nn.Tensor\n        The output tensor, with shape of [m, n] and dtype of `out_dtype`.\n    \"\"\"\n    assert x.ndim >= 2\n    assert weight.ndim == 2\n    assert x_scale.ndim == x.ndim\n    assert weight_scale.ndim == weight.ndim\n    
assert x.shape[-1] == weight.shape[1]\n    assert x.shape[:-1] == x_scale.shape[:-1]\n    assert (x.shape[-1] + block_size[1] - 1) // block_size[1] == x_scale.shape[-1]\n    assert (weight.shape[1] + block_size[1] - 1) // block_size[1] == weight_scale.shape[1]\n    assert (weight.shape[0] + block_size[0] - 1) // block_size[0] == weight_scale.shape[0]\n\n    if x.dtype != \"float8_e4m3fn\" or weight.dtype != \"float8_e4m3fn\":\n        raise ValueError(\n            f\"x and weight must be float8_e4m3fn, but got x={x.dtype}, weight={weight.dtype}\"\n        )\n    if x_scale.dtype != \"float32\" or weight_scale.dtype != \"float32\":\n        raise ValueError(\n            \"x_scale and weight_scale must be float32, but got \"\n            f\"x_scale={x_scale.dtype}, weight_scale={weight_scale.dtype}\"\n        )\n    if out_dtype not in [\"float16\", \"bfloat16\"]:\n        raise ValueError(f\"out_dtype must be float16 or bfloat16, but got {out_dtype}\")\n\n    M = x.shape[0]\n    for i in range(1, x.ndim - 1):\n        M *= x.shape[i]\n    N = weight.shape[0]\n    K = x.shape[-1]\n\n    BLOCK_SIZE_M = 64\n    BLOCK_SIZE_N = block_size[0]\n    BLOCK_SIZE_K = block_size[1]\n    GROUP_SIZE_M = 32\n    num_warps = 4\n    num_stages = 3\n\n    x_shape = x.shape\n    if x.ndim > 2:\n        x = x.reshape(M, K)\n    x_scale = x_scale.reshape(M, x_scale.shape[-1])\n\n    out = nn.extern(\n        \"mlc.triton.w8a8_block_fp8_matmul\",\n        args=[\n            x,\n            weight,\n            x_scale,\n            weight_scale,\n            N,\n            K,\n            block_size[0],\n            block_size[1],\n            BLOCK_SIZE_M,\n            BLOCK_SIZE_N,\n            BLOCK_SIZE_K,\n            GROUP_SIZE_M,\n            num_warps,\n            num_stages,\n            str(x.dtype),\n            str(out_dtype),\n        ],\n        out=nn.Tensor.placeholder((M, N), dtype=out_dtype),\n    )\n    return out.reshape(*x_shape[:-1], N) if len(x_shape) > 2 
else out\n\n\ndef fp8_groupwise_scaled_group_gemm(  # pylint: disable=too-many-arguments,too-many-locals\n    x: nn.Tensor,\n    x_scale: nn.Tensor,\n    weight: nn.Tensor,\n    weight_scale: nn.Tensor,\n    indptr: nn.Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n):\n    \"\"\"Triton block-scale fp8 group gemm operator.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor, with shape of [m, k].\n\n    x_scale : nn.Tensor\n        The scale tensor, with shape of [m, k // block_size].\n\n    weight : nn.Tensor\n        The weight tensor, with shape of [num_experts, n, k].\n\n    weight_scale : nn.Tensor\n        The scale tensor, with shape of [num_experts, n // block_size, k // block_size].\n\n    indptr : nn.Tensor\n        The indptr tensor of group gemm, with shape of [num_experts + 1,].\n\n    block_size : Tuple[int, int]\n        The block size.\n\n    out_dtype : str\n        The data type of the output tensor.\n\n    Returns\n    -------\n    out : nn.Tensor\n        The output tensor, with shape of [m, n] and dtype of `out_dtype`.\n    \"\"\"\n    assert x.ndim >= 2\n    assert weight.ndim == 3\n    assert x_scale.ndim == x.ndim\n    assert weight_scale.ndim == weight.ndim\n    assert x.shape[-1] == weight.shape[2]\n    assert (x.shape[-1] + block_size[1] - 1) // block_size[1] == x_scale.shape[-1]\n    assert (weight.shape[2] + block_size[1] - 1) // block_size[1] == weight_scale.shape[2]\n    assert (weight.shape[1] + block_size[0] - 1) // block_size[0] == weight_scale.shape[1]\n\n    num_experts = weight.shape[0]\n    M = x.shape[0]\n    for i in range(1, x.ndim - 1):\n        M *= x.shape[i]\n    N = weight.shape[1]\n    K = x.shape[-1]\n    assert weight_scale.shape[0] == num_experts\n    assert indptr.ndim == 1\n    assert indptr.shape[0] == num_experts + 1\n\n    BLOCK_SIZE_M = 64\n    BLOCK_SIZE_N = block_size[0]\n    BLOCK_SIZE_K = block_size[1]\n    GROUP_SIZE_M = 32\n    num_warps = 4\n    num_stages = 
3\n\n    x_shape = x.shape\n    if x.ndim > 2:\n        x = x.reshape(M, K)\n    x_scale = x_scale.reshape(M, x_scale.shape[-1])\n    expert_ids = _compute_expert_id_per_block(indptr, num_experts, M, BLOCK_SIZE_M)\n\n    out = nn.extern(\n        \"mlc.triton.w8a8_block_fp8_group_matmul\",\n        args=[\n            x,\n            weight,\n            x_scale,\n            weight_scale,\n            expert_ids,\n            indptr,\n            N,\n            K,\n            num_experts,\n            block_size[0],\n            block_size[1],\n            BLOCK_SIZE_M,\n            BLOCK_SIZE_N,\n            BLOCK_SIZE_K,\n            GROUP_SIZE_M,\n            num_warps,\n            num_stages,\n            str(x.dtype),\n            str(out_dtype),\n        ],\n        out=nn.Tensor.placeholder((M, N), dtype=out_dtype),\n    )\n    return out.reshape(*x_shape[:-1], N) if len(x_shape) > 2 else out\n"
  },
  {
    "path": "python/mlc_llm/protocol/__init__.py",
    "content": "\"\"\"Definitions of pydantic models for API entry points and configurations\n\nNote\n----\nWe use the following convention:\n\n- filename_protocol: if the classes can appear in an API endpoint\n- filename_config: for other config classes\n\"\"\"\n"
  },
  {
    "path": "python/mlc_llm/protocol/conversation_protocol.py",
    "content": "\"\"\"The standard conversation protocol in MLC LLM\"\"\"\n\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union\n\nfrom pydantic import BaseModel, Field, field_validator\n\n\n# The message placeholders in the message prompts according to roles.\nclass MessagePlaceholders(Enum):\n    \"\"\"The message placeholders in the message prompts according to roles.\"\"\"\n\n    SYSTEM = \"{system_message}\"\n    USER = \"{user_message}\"\n    ASSISTANT = \"{assistant_message}\"\n    TOOL = \"{tool_message}\"\n    FUNCTION = \"{function_string}\"\n\n\nT = TypeVar(\"T\", bound=\"BaseModel\")\n\n\nclass Conversation(BaseModel):\n    \"\"\"Class that specifies the conversation template\n    and contains the conversation history.\n\n    Given a conversation template, the prompt generated from it\n    usually takes the following format:\n\n      <<system>><<messages[0][0]>><<role_content_sep>><<messages[0][1]>><<seps[0]>>\n                <<messages[1][0]>><<role_content_sep>><<messages[1][1]>><<seps[1]>>\n                ...\n                <<messages[2][0]>><<role_content_sep>><<messages[2][1]>><<seps[0]>>\n                <<roles[1]>><<role_empty_sep>>\n    \"\"\"\n\n    # Optional name of the template.\n    name: Optional[str] = None\n    # The system prompt template. It optionally contains the system\n    # message placeholder, which will be replaced with\n    # the system message below.\n    system_template: str = MessagePlaceholders.SYSTEM.value\n    # The content of the system prompt (without the template format).\n    system_message: str = \"\"\n    # The system token ids to be prepended at the beginning of the tokenized\n    # generated prompt.\n    system_prefix_token_ids: Optional[List[int]] = None\n    # Whether or not to append user role and separator after the system message.\n    # This is mainly for the [INST] [/INST] style prompt format.\n    
add_role_after_system_message: bool = True\n\n    # The conversation roles\n    roles: Dict[str, str]\n\n    # The role prompt templates. Each optionally contains the default\n    # message placeholder, which will be replaced by the actual content.\n    role_templates: Dict[str, str]\n\n    # The conversation history messages.\n    # Each message is a \"(role, content)\" pair. The content can be a string,\n    # a list of content-item dicts, or None.\n    messages: List[Tuple[str, Optional[Union[str, List[Dict]]]]] = Field(default_factory=lambda: [])\n\n    # The separators between messages when concatenating into a single prompt.\n    # List size should be either 1 or 2.\n    # - When size is 1, the separator will be used between adjacent messages.\n    # - When size is 2, seps[0] is used after user message, and\n    #   seps[1] is used after assistant message.\n    seps: List[str]\n\n    # The separator between the role and the content in a message.\n    role_content_sep: str = \"\"\n    # The separator between the role and empty contents.\n    role_empty_sep: str = \"\"\n\n    # The stop criteria\n    stop_str: List[str] = Field(default_factory=lambda: [])\n    stop_token_ids: List[int] = Field(default_factory=lambda: [])\n\n    # Function call fields\n    function_string: str = \"\"\n    # Whether function calling is used; helps check the output message format in API calls\n    use_function_calling: bool = False\n\n    def __init__(self, role_templates: Optional[Dict[str, str]] = None, **kwargs):\n        # Default templates, which are overridden by model-specific templates\n        _role_templates: Dict[str, str] = {\n            \"user\": MessagePlaceholders.USER.value,\n            \"assistant\": MessagePlaceholders.ASSISTANT.value,\n            \"tool\": MessagePlaceholders.TOOL.value,\n        }\n        if role_templates is not None:\n            _role_templates.update(role_templates)\n        super().__init__(role_templates=_role_templates, **kwargs)\n\n    
@field_validator(\"seps\")\n    @classmethod\n    def check_message_seps(cls, seps: List[str]) -> List[str]:\n        \"\"\"Check if the input message separators have size 1 or 2.\"\"\"\n        if len(seps) == 0 or len(seps) > 2:\n            raise ValueError(\"seps should have size 1 or 2.\")\n        return seps\n\n    def to_json_dict(self) -> Dict[str, Any]:\n        \"\"\"Convert to a json dictionary\"\"\"\n        return self.model_dump(by_alias=True, exclude_none=True)\n\n    @classmethod\n    def from_json_dict(cls: Type[T], json_dict: Dict[str, Any]) -> T:\n        \"\"\"Convert from a json dictionary\"\"\"\n        return cls.model_validate(json_dict)\n\n    # pylint: disable=too-many-branches\n    def as_prompt(self, config=None) -> List[Any]:\n        \"\"\"Convert the conversation template and history messages to\n        a single prompt.\n\n        Returns\n        -------\n        prompts : List[Union[str, \"mlc_llm.serve.data.Data\"]]\n            The prompts converted from the conversation messages.\n            We use Any in the signature to avoid cyclic import.\n        \"\"\"\n        from ..serve import data  # pylint: disable=import-outside-toplevel\n\n        # - Get the system message.\n        system_msg = self.system_template.replace(\n            MessagePlaceholders.SYSTEM.value, self.system_message\n        )\n\n        # - Get the message strings.\n        message_list: List[Union[str, data.Data]] = []\n        separators = list(self.seps)\n        if len(separators) == 1:\n            separators.append(separators[0])\n\n        if system_msg != \"\":\n            message_list.append(system_msg)\n\n        for i, (role, content) in enumerate(self.messages):  # pylint: disable=not-an-iterable\n            if role not in self.roles.keys():\n                raise ValueError(f'Role \"{role}\" is not a supported role in {self.roles.keys()}')\n            separator = separators[role == \"assistant\"]  # check assistant role\n\n       
     if content is None:\n                message_list.append(self.roles[role] + self.role_empty_sep)\n                continue\n\n            role_prefix = (\n                \"\"\n                # Do not append role prefix if this is the first message and there\n                # is already a system message\n                if (not self.add_role_after_system_message and system_msg != \"\" and i == 0)\n                else self.roles[role] + self.role_content_sep\n            )\n            if isinstance(content, str):\n                message_list.append(\n                    role_prefix\n                    + self.role_templates[role].replace(\n                        MessagePlaceholders[role.upper()].value, content\n                    )\n                    + separator\n                )\n                continue\n\n            message_list.append(role_prefix)\n\n            for item in content:\n                assert isinstance(item, dict), \"Content should be a string or a list of dicts\"\n                assert \"type\" in item, \"Content item should have a type field\"\n                if item[\"type\"] == \"text\":\n                    message = self.role_templates[role].replace(\n                        MessagePlaceholders[role.upper()].value, item[\"text\"]\n                    )\n                    message_list.append(message)\n                elif item[\"type\"] == \"image_url\":\n                    assert config is not None, \"Model config is required\"\n                    image_url = _get_url_from_item(item)\n                    message_list.append(data.ImageData.from_url(image_url, config))\n                    message_list.append(\"\\n\")\n                else:\n                    raise ValueError(f\"Unsupported content type: {item['type']}\")\n\n            message_list.append(separator)\n\n        prompt = _combine_consecutive_messages(message_list)\n\n        if not any(isinstance(item, data.ImageData) for item in message_list):\n         
   # Replace the last function string placeholder with the actual function string\n            prompt[0] = self.function_string.join(\n                prompt[0].rsplit(MessagePlaceholders.FUNCTION.value, 1)\n            )\n            # Replace the remaining function string placeholders with an empty string\n            prompt[0] = prompt[0].replace(MessagePlaceholders.FUNCTION.value, \"\")\n\n        return prompt\n\n\ndef _get_url_from_item(item: Dict) -> str:\n    image_url: str\n    assert \"image_url\" in item, \"Content item should have an image_url field\"\n    if isinstance(item[\"image_url\"], str):\n        image_url = item[\"image_url\"]\n    elif isinstance(item[\"image_url\"], dict):\n        assert (\n            \"url\" in item[\"image_url\"]\n        ), \"Content image_url item should be a string or a dict with a url field\"  # pylint: disable=line-too-long\n        image_url = item[\"image_url\"][\"url\"]\n    else:\n        raise ValueError(\n            \"Content image_url item type not supported. \"\n            \"Should be a string or a dict with a url field.\"\n        )\n    return image_url\n\n\ndef _combine_consecutive_messages(messages: List[Any]) -> List[Any]:\n    \"\"\"Combine consecutive strings into one.\n\n    Parameters\n    ----------\n    messages : List[Union[str, \"mlc_llm.serve.data.Data\"]]\n        The input messages to be combined.\n        We use Any in the signature to avoid cyclic import.\n\n    Returns\n    -------\n    updated_messages : List[Union[str, \"mlc_llm.serve.data.Data\"]]\n        The combined messages\n    \"\"\"\n    if len(messages) == 0:\n        return []\n\n    combined_messages = [messages[0]]\n    for message in messages[1:]:\n        if isinstance(message, str) and isinstance(combined_messages[-1], str):\n            combined_messages[-1] += message\n        else:\n            combined_messages.append(message)\n    return combined_messages\n"
  },
  {
    "path": "python/mlc_llm/protocol/debug_protocol.py",
    "content": "\"\"\"Debug protocols in MLC LLM\"\"\"\n\nfrom typing import Literal, Optional\n\nfrom pydantic import BaseModel\n\n\nclass DisaggConfig(BaseModel):\n    \"\"\"The class of metadata used in microserving APIs.\"\"\"\n\n    kind: Optional[Literal[\"prepare_receive\", \"remote_send\", \"start_generation\"]] = None\n    # \"kv_append_metadata\" is base64-encoded and is thus a string.\n    kv_append_metadata: Optional[str] = None\n    # \"kv_window_begin\" and \"kv_window_end\" denote the KV interval of interest.\n    # \"kv_window_end\" supports Python-style negative indexing.\n    # The concrete meaning varies by special request kind:\n    # - For \"prepare_receive\", the begin is always 0, and \"[0:end]\" denotes\n    # the KV range to prefill on a prefill instance.\n    # - For \"remote_send\", \"[begin:end]\" denotes the KV range to prefill\n    # and send to the decode instance.\n    # - For \"start_generation\", the end is always None, and \"[begin:]\" denotes\n    # the KV range to prefill locally on the decode instance.\n    kv_window_begin: Optional[int] = None\n    kv_window_end: Optional[int] = None\n    # KV data destination group offset\n    dst_group_offset: Optional[int] = None\n\n\nclass DebugConfig(BaseModel):\n    \"\"\"The class of debug options.\n\n    These options are available to the engine\n    but are not available to the serving endpoint\n    unless --enable-debug is explicitly passed.\n    \"\"\"\n\n    ignore_eos: bool = False\n    pinned_system_prompt: bool = False\n    special_request: Optional[Literal[\"query_engine_metrics\"]] = None\n    grammar_execution_mode: Literal[\"constraint\", \"jump_forward\"] = \"jump_forward\"\n    disagg_config: Optional[DisaggConfig] = None\n\n    \"\"\"Special request indicators\n\n    Special requests are handled differently by the engine and do not go\n    through the normal engine step flow.\n\n    The results of these requests are returned in the \"usage\" field.\n    \"\"\"\n"
  },
  {
    "path": "python/mlc_llm/protocol/error_protocol.py",
    "content": "\"\"\"Error protocols in MLC LLM\"\"\"\n\nfrom http import HTTPStatus\nfrom typing import Optional\n\nimport fastapi\nfrom pydantic import BaseModel\n\n\nclass BadRequestError(ValueError):\n    \"\"\"The exception for bad requests in engines.\"\"\"\n\n    def __init__(self, *args: object) -> None:\n        super().__init__(*args)\n\n\nclass ErrorResponse(BaseModel):\n    \"\"\"The class of error response.\"\"\"\n\n    object: str = \"error\"\n    message: str\n    code: Optional[int] = None\n\n\ndef create_error_response(status_code: HTTPStatus, message: str) -> fastapi.responses.JSONResponse:\n    \"\"\"Create a JSON response that reports an error with the given message.\"\"\"\n    # Pass a dict rather than a JSON string: JSONResponse serializes its content\n    # itself, so passing model_dump_json() output would double-encode the body.\n    return fastapi.responses.JSONResponse(\n        ErrorResponse(message=message, code=status_code.value).model_dump(by_alias=True),\n        status_code=status_code.value,\n    )\n\n\nasync def bad_request_error_handler(_request: fastapi.Request, e: BadRequestError):\n    \"\"\"The handler of BadRequestError that converts an exception into an error response.\"\"\"\n    return create_error_response(status_code=HTTPStatus.BAD_REQUEST, message=e.args[0])\n"
  },
  {
    "path": "python/mlc_llm/protocol/generation_config.py",
    "content": "\"\"\"Low-level generation config class\"\"\"\n\n# pylint: disable=missing-class-docstring,too-many-instance-attributes\nfrom typing import Dict, List, Optional\n\nfrom pydantic import BaseModel\n\nfrom .debug_protocol import DebugConfig\nfrom .openai_api_protocol import RequestResponseFormat\n\n\nclass GenerationConfig(BaseModel):\n    \"\"\"The generation configuration model.\n\n    This is a config class used by Engine internally.\n    \"\"\"\n\n    n: int = 1\n    temperature: Optional[float] = None\n    top_p: Optional[float] = None\n    frequency_penalty: Optional[float] = None\n    presence_penalty: Optional[float] = None\n    repetition_penalty: Optional[float] = None\n    logprobs: bool = False\n    top_logprobs: int = 0\n    logit_bias: Optional[Dict[int, float]] = None\n    # Internally, -1 means no limit on the number of generated tokens\n    max_tokens: int = -1\n    seed: Optional[int] = None\n    stop_strs: Optional[List[str]] = None\n    stop_token_ids: Optional[List[int]] = None\n    response_format: Optional[RequestResponseFormat] = None\n    debug_config: Optional[DebugConfig] = None\n"
  },
  {
    "path": "python/mlc_llm/protocol/microserving_protocol.py",
    "content": "\"\"\"Protocols in MLC LLM for MicroServing.\"\"\"\n\nfrom pydantic import BaseModel\n\nfrom mlc_llm.protocol.openai_api_protocol import CompletionRequest\n\n\nclass PrepRecvRequest(CompletionRequest):\n    \"\"\"The extra request body for prep_recv request in MicroServing.\n\n    Attributes\n    ----------\n    end : int\n        [0:end] denotes the KV range of the prompt to prefill on\n        a prefill instance.\n        The entries of this KV range will be allocated on the decode instance.\n    \"\"\"\n\n    end: int\n\n\nclass PrepRecvResponse(BaseModel):\n    \"\"\"The response body for prep_recv request in MicroServing.\n\n    Attributes\n    ----------\n    prefix_matched_length : int\n        The matched common prefix length on the decode instance when\n        prefix cache is enabled, or 0 if there is no prefix cache.\n\n    kv_append_metadata : str\n        The metadata of the KV range on the destination decode instance.\n    \"\"\"\n\n    kv_append_metadata: str\n    prefix_matched_length: int\n\n\nclass RemoteSendRequest(CompletionRequest):\n    \"\"\"The extra request body for remote_send request in MicroServing.\n\n    Attributes\n    ----------\n    begin : int\n        The start of the KV range to prefill.\n\n    end : int\n        The end of the KV range to prefill.\n\n    kv_addr_info : str\n        The metadata of the KV range on the destination decode instance.\n\n    recv_rank : int\n        The node group offset of the destination decode instance.\n    \"\"\"\n\n    begin: int\n    end: int\n    kv_addr_info: str\n    recv_rank: int\n\n\nclass StartGenerateRequest(CompletionRequest):\n    \"\"\"The extra request body for start_generate request in MicroServing.\n\n    Attributes\n    ----------\n    begin : int\n        The start of the KV range to prefill on the decode instance.\n    \"\"\"\n\n    begin: int\n"
  },
  {
    "path": "python/mlc_llm/protocol/mlc_chat_config.py",
    "content": "# pylint: disable=too-many-instance-attributes\n\"\"\"Schema for mlc-chat-config\"\"\"\n\nfrom typing import Any, Dict, List, Literal, Optional, Union\n\nfrom pydantic import BaseModel, Field\n\nfrom mlc_llm.support.constants import MLC_CHAT_CONFIG_VERSION\n\nfrom .conversation_protocol import Conversation\n\nMLC_CHAT_SYSTEM_DEFAULT = {\n    \"pad_token_id\": 0,\n    \"bos_token_id\": 1,\n    \"eos_token_id\": 2,\n    \"temperature\": 1.0,\n    \"presence_penalty\": 0.0,\n    \"frequency_penalty\": 0.0,\n    \"repetition_penalty\": 1.0,\n    \"top_p\": 1.0,\n}\n\"\"\"System default values.\"\"\"\n\n\nclass MLCChatConfig(BaseModel):\n    \"\"\"Fields in the dumped `mlc-chat-config.json` file.\"\"\"\n\n    # Version control\n    version: str = MLC_CHAT_CONFIG_VERSION\n\n    # use alias to avoid protected namespace conflict with pydantic\n    field_model_type: str = Field(alias=\"model_type\")\n    quantization: str\n    # use alias to avoid protected namespace conflict with pydantic\n    field_model_config: Dict[str, Any] = Field(alias=\"model_config\")\n    vocab_size: int\n    context_window_size: int\n    sliding_window_size: int\n    prefill_chunk_size: int\n    attention_sink_size: int\n    tensor_parallel_shards: int\n    pipeline_parallel_stages: int = 1\n    # Configuration of text generation\n    active_vocab_size: Optional[int] = None\n    temperature: Optional[float] = None\n    presence_penalty: Optional[float] = None\n    frequency_penalty: Optional[float] = None\n    repetition_penalty: Optional[float] = None\n    top_p: Optional[float] = None\n    # Tokenizer configuration\n    tokenizer_files: List[str] = Field(default_factory=list)\n    # The content of tokenizer.TokenizerInfo\n    tokenizer_info: Dict[str, Any] = Field(default_factory=dict)\n    # conversation template\n    conv_template: Conversation\n    # extra fields from generation_config.json\n    # NOTE: they are not currently used in MLCEngine,\n    # but we keep them for bookkeeping purposes\n    pad_token_id: Optional[int] = None\n    bos_token_id: Optional[int] = None\n    eos_token_id: Optional[Union[int, List[int]]] = None\n\n    model_task: Literal[\"chat\", \"embedding\"] = \"chat\"\n    embedding_metadata: Optional[Dict[str, Any]] = None\n\n    def get_system_defaults_for_missing_fields(self) -> Dict[str, Any]:\n        \"\"\"Return system default values for fields that are None.\n\n        Note\n        ----\n        Defaults are handled this way so that MLCChatConfig can be created lazily,\n        its optional values overridden, and the system defaults applied at the end.\n        \"\"\"\n        res = {}\n        for key, value in MLC_CHAT_SYSTEM_DEFAULT.items():\n            if getattr(self, key) is None:\n                res[key] = value\n        return res\n"
  },
  {
    "path": "python/mlc_llm/protocol/openai_api_protocol.py",
    "content": "\"\"\"Protocols in MLC LLM for OpenAI API.\nAdapted from FastChat's OpenAI protocol:\nhttps://github.com/lm-sys/FastChat/blob/main/fastchat/protocol/openai_api_protocol.py\n\"\"\"\n\n# pylint: disable=missing-class-docstring\n\nimport json\nimport time\nfrom typing import Any, Dict, List, Literal, Optional, Tuple, Union\n\nimport shortuuid\nfrom pydantic import BaseModel, Field, field_validator, model_validator\n\nfrom .conversation_protocol import Conversation\nfrom .debug_protocol import DebugConfig\nfrom .error_protocol import BadRequestError\n\n################ Commons ################\n\n\n# OpenAI API compatible limits\nCHAT_COMPLETION_MAX_TOP_LOGPROBS = 20\nCOMPLETION_MAX_TOP_LOGPROBS = 5\n\n\nclass ListResponse(BaseModel):\n    object: str = \"list\"\n    data: List[Any]\n\n\nclass TopLogProbs(BaseModel):\n    token: str\n    logprob: float\n    bytes: Optional[List[int]]\n\n\nclass LogProbsContent(BaseModel):\n    token: str\n    logprob: float\n    bytes: Optional[List[int]]\n    top_logprobs: List[TopLogProbs] = []\n\n\nclass LogProbs(BaseModel):\n    content: List[LogProbsContent]\n\n\nclass CompletionLogProbs(BaseModel):\n    # The position of the token in the concatenated str: prompt + completion_text\n    # TODO(vvchernov): skip optional after support\n    text_offset: Optional[List[int]]\n    token_logprobs: List[float]\n    tokens: List[str]\n    top_logprobs: List[Dict[str, float]]\n\n\nclass CompletionUsage(BaseModel):\n    prompt_tokens: int\n    completion_tokens: int\n    total_tokens: int\n    extra: Optional[Dict[str, Any]] = None\n    \"\"\"Extra metrics and info that may be returned by debug_config\n    \"\"\"\n\n\nclass StreamOptions(BaseModel):\n    include_usage: Optional[bool]\n\n\n################ v1/embeddings ################\n\n\nclass EmbeddingRequest(BaseModel):\n    \"\"\"OpenAI \"v1/embeddings\" request protocol.\n    API reference: https://platform.openai.com/docs/api-reference/embeddings/create\n    \"\"\"\n\n    input: Union[str, List[str], List[int], List[List[int]]]\n    model: Optional[str] = None\n    encoding_format: Literal[\"float\", \"base64\"] = \"float\"\n    dimensions: Optional[int] = None\n    user: Optional[str] = None\n\n    @field_validator(\"input\")\n    @classmethod\n    def validate_input(cls, v):\n        \"\"\"Check that the input is not an empty list.\n\n        Note: empty strings are allowed, since encoder models produce valid\n        embeddings from [CLS]+[SEP] tokens alone.\n        \"\"\"\n        if isinstance(v, list) and len(v) == 0:\n            raise ValueError(\"Input list must not be empty.\")\n        return v\n\n\nclass EmbeddingObject(BaseModel):\n    object: str = \"embedding\"\n    embedding: Union[List[float], str]\n    index: int\n\n\nclass EmbeddingUsage(BaseModel):\n    prompt_tokens: int\n    total_tokens: int\n\n\nclass EmbeddingResponse(BaseModel):\n    \"\"\"OpenAI \"v1/embeddings\" response protocol.\n    API reference: https://platform.openai.com/docs/api-reference/embeddings/object\n    \"\"\"\n\n    object: str = \"list\"\n    data: List[EmbeddingObject]\n    model: Optional[str] = None\n    usage: EmbeddingUsage\n\n\n################ v1/models ################\n\n\nclass ModelResponse(BaseModel):\n    \"\"\"OpenAI \"v1/models\" response protocol.\n    API reference: https://platform.openai.com/docs/api-reference/models/object\n    \"\"\"\n\n    id: str\n    created: int = Field(default_factory=lambda: int(time.time()))\n    object: str = \"model\"\n    owned_by: str = \"MLC-LLM\"\n\n\n################ v1/completions ################\n\n\nclass RequestResponseFormat(BaseModel):\n    type: Literal[\"text\", \"json_object\"] = \"text\"\n    json_schema: Optional[str] = Field(default=None, alias=\"schema\")\n    \"\"\"This field is named json_schema instead of schema because BaseModel defines a method called\n    schema. 
During construction of RequestResponseFormat, the key \"schema\" should still be used:\n    `RequestResponseFormat(type=\"json_object\", schema=\"{}\")`\n    \"\"\"\n\n\nclass CompletionRequest(BaseModel):\n    \"\"\"OpenAI completion request protocol.\n    API reference: https://platform.openai.com/docs/api-reference/completions/create\n    \"\"\"\n\n    model: Optional[str] = None\n    prompt: Union[str, List[int]]\n    best_of: int = 1\n    echo: bool = False\n    frequency_penalty: Optional[float] = None\n    presence_penalty: Optional[float] = None\n    logprobs: Optional[int] = None\n    logit_bias: Optional[Dict[int, float]] = None\n    max_tokens: Optional[int] = None\n    n: int = 1\n    seed: Optional[int] = None\n    stop: Optional[Union[str, List[str]]] = None\n    stream: bool = False\n    stream_options: Optional[StreamOptions] = None\n    suffix: Optional[str] = None\n    temperature: Optional[float] = None\n    top_p: Optional[float] = None\n    user: Optional[str] = None\n    response_format: Optional[RequestResponseFormat] = None\n    debug_config: Optional[DebugConfig] = None\n\n    @field_validator(\"frequency_penalty\", \"presence_penalty\")\n    @classmethod\n    def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]:\n        \"\"\"Check if the penalty value is in range [-2, 2].\"\"\"\n        if penalty_value and (penalty_value < -2 or penalty_value > 2):\n            raise ValueError(\"Penalty value should be in range [-2, 2].\")\n        return penalty_value\n\n    @field_validator(\"logit_bias\")\n    @classmethod\n    def check_logit_bias(\n        cls, logit_bias_value: Optional[Dict[int, float]]\n    ) -> Optional[Dict[int, float]]:\n        \"\"\"Check that each logit bias value is in range [-100, 100].\"\"\"\n        if logit_bias_value is None:\n            return None\n        for token_id, bias in logit_bias_value.items():\n            if abs(bias) > 100:\n                raise ValueError(\n                    
\"Logit bias value should be in range [-100, 100], while value \"\n                    f\"{bias} is given for token id {token_id}\"\n                )\n        return logit_bias_value\n\n    @model_validator(mode=\"after\")\n    def check_logprobs(self) -> \"CompletionRequest\":\n        \"\"\"Check if the logprobs requirements are valid.\"\"\"\n        if self.logprobs is not None and (\n            self.logprobs < 0 or self.logprobs > COMPLETION_MAX_TOP_LOGPROBS\n        ):\n            raise ValueError(f'\"logprobs\" must be in range [0, {COMPLETION_MAX_TOP_LOGPROBS}]')\n        return self\n\n\nclass CompletionResponseChoice(BaseModel):\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"preempt\"]] = None\n    index: int = 0\n    logprobs: Optional[CompletionLogProbs] = None\n    text: str\n\n\nclass CompletionResponse(BaseModel):\n    \"\"\"OpenAI completion response protocol.\n    API reference: https://platform.openai.com/docs/api-reference/completions/object\n    \"\"\"\n\n    id: str\n    choices: List[CompletionResponseChoice]\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: Optional[str] = None\n    object: str = \"text_completion\"\n    usage: Optional[CompletionUsage] = None\n\n\n################ v1/chat/completions ################\n\n\nclass ChatFunction(BaseModel):\n    description: Optional[str] = None\n    name: str\n    parameters: Dict\n\n\nclass ChatTool(BaseModel):\n    type: Literal[\"function\"]\n    function: ChatFunction\n\n\nclass ChatFunctionCall(BaseModel):\n    name: str\n    arguments: Union[None, Dict[str, Any]] = None\n\n\nclass ChatToolCall(BaseModel):\n    id: str = Field(default_factory=lambda: f\"call_{shortuuid.random()}\")\n    type: Literal[\"function\"]\n    function: ChatFunctionCall\n\n\nclass ChatCompletionMessage(BaseModel):\n    content: Optional[Union[str, List[Dict]]] = None\n    role: Literal[\"system\", \"user\", \"assistant\", \"tool\"]\n    name: Optional[str] = None\n    
tool_calls: Optional[List[ChatToolCall]] = None\n    tool_call_id: Optional[str] = None\n\n\nclass ChatCompletionRequest(BaseModel):\n    \"\"\"OpenAI chat completion request protocol.\n    API reference: https://platform.openai.com/docs/api-reference/chat/create\n    \"\"\"\n\n    messages: List[ChatCompletionMessage]\n    model: Optional[str] = None\n    frequency_penalty: Optional[float] = None\n    presence_penalty: Optional[float] = None\n    logprobs: bool = False\n    top_logprobs: int = 0\n    logit_bias: Optional[Dict[int, float]] = None\n    max_tokens: Optional[int] = None\n    n: int = 1\n    seed: Optional[int] = None\n    stop: Optional[Union[str, List[str]]] = None\n    stream: bool = False\n    stream_options: Optional[StreamOptions] = None\n    temperature: Optional[float] = None\n    top_p: Optional[float] = None\n    tools: Optional[List[ChatTool]] = None\n    tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None\n    user: Optional[str] = None\n    response_format: Optional[RequestResponseFormat] = None\n    # NOTE: debug_config is not part of OpenAI protocol\n    # we add it to enable extra debug options\n    debug_config: Optional[DebugConfig] = None\n\n    @field_validator(\"frequency_penalty\", \"presence_penalty\")\n    @classmethod\n    def check_penalty_range(cls, penalty_value: Optional[float]) -> Optional[float]:\n        \"\"\"Check if the penalty value is in range [-2, 2].\"\"\"\n        if penalty_value and (penalty_value < -2 or penalty_value > 2):\n            raise ValueError(\"Penalty value should be in range [-2, 2].\")\n        return penalty_value\n\n    @field_validator(\"logit_bias\")\n    @classmethod\n    def check_logit_bias(\n        cls, logit_bias_value: Optional[Dict[int, float]]\n    ) -> Optional[Dict[int, float]]:\n        \"\"\"Check that each logit bias value is in range [-100, 100].\"\"\"\n        if logit_bias_value is None:\n            return None\n        for token_id, bias in logit_bias_value.items():\n            if abs(bias) > 100:\n                raise ValueError(\n                    \"Logit bias value should be in range [-100, 100], while value \"\n                    f\"{bias} is given for token id {token_id}\"\n                )\n        return logit_bias_value\n\n    @model_validator(mode=\"after\")\n    def check_logprobs(self) -> \"ChatCompletionRequest\":\n        \"\"\"Check if the logprobs requirements are valid.\"\"\"\n        if self.top_logprobs < 0 or self.top_logprobs > CHAT_COMPLETION_MAX_TOP_LOGPROBS:\n            raise ValueError(\n                f'\"top_logprobs\" must be in range [0, {CHAT_COMPLETION_MAX_TOP_LOGPROBS}]'\n            )\n        if not self.logprobs and self.top_logprobs > 0:\n            raise ValueError('\"logprobs\" must be True to support \"top_logprobs\"')\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_stream_options(self) -> \"ChatCompletionRequest\":\n        \"\"\"Check stream options\"\"\"\n        if self.stream_options is None:\n            return self\n        if not self.stream:\n            raise ValueError(\"stream must be set to True when stream_options is present\")\n        return self\n\n    @model_validator(mode=\"after\")\n    def check_debug_config(self) -> \"ChatCompletionRequest\":\n        \"\"\"Check debug config\"\"\"\n        if self.debug_config is None:\n            return self\n\n        if self.debug_config.special_request is None:\n            return self\n\n        if not self.stream:\n            raise ValueError(\"DebugConfig.special_request requires stream=True\")\n\n        if self.stream_options is None or not self.stream_options.include_usage:\n            raise ValueError(\"DebugConfig.special_request requires include_usage in stream_options\")\n\n        return self\n\n    def check_message_validity(self) -> None:\n        \"\"\"Check if the given chat messages are valid. 
Raise BadRequestError if invalid.\"\"\"\n        for i, message in enumerate(self.messages):\n            if message.role == \"system\" and i != 0:\n                raise BadRequestError(\n                    f\"System prompt at position {i} in the message list is invalid.\"\n                )\n            if message.tool_call_id is not None:\n                if message.role != \"tool\":\n                    raise BadRequestError(\"Non-tool message having `tool_call_id` is invalid.\")\n            if isinstance(message.content, list):\n                if message.role != \"user\":\n                    raise BadRequestError(\"Non-user message having a list of content is invalid.\")\n            if message.tool_calls is not None:\n                if message.role != \"assistant\":\n                    raise BadRequestError(\"Non-assistant message having `tool_calls` is invalid.\")\n                raise BadRequestError(\"Assistant message having `tool_calls` is not supported yet.\")\n\n    def check_function_call_usage(self, conv_template: Conversation) -> None:\n        \"\"\"Check if function calling is used and update the conversation template.\n        Raise BadRequestError if the request format for function calling is invalid.\n        \"\"\"\n\n        # return if no tools are provided or tool_choice is set to none\n        if self.tools is None or (isinstance(self.tool_choice, str) and self.tool_choice == \"none\"):\n            conv_template.use_function_calling = False\n            return\n\n        # select the tool based on the tool_choice if specified\n        if isinstance(self.tool_choice, dict):\n            if self.tool_choice[\"type\"] != \"function\":  # pylint: disable=unsubscriptable-object\n                raise BadRequestError(\"Only 'function' tool choice is supported\")\n\n            if len(self.tool_choice[\"function\"]) > 1:  # pylint: disable=unsubscriptable-object\n                raise BadRequestError(\"Only one tool is supported when tool_choice 
is specified\")\n\n            for tool in self.tools:  # pylint: disable=not-an-iterable\n                if (\n                    tool.function.name\n                    == self.tool_choice[\"function\"][  # pylint: disable=unsubscriptable-object\n                        \"name\"\n                    ]\n                ):\n                    conv_template.use_function_calling = True\n                    conv_template.function_string = tool.function.model_dump_json(by_alias=True)\n                    return\n\n            # pylint: disable=unsubscriptable-object\n            raise BadRequestError(\n                f\"The tool_choice function {self.tool_choice['function']['name']}\"\n                \" is not found in the tools list\"\n            )\n            # pylint: enable=unsubscriptable-object\n\n        if isinstance(self.tool_choice, str) and self.tool_choice != \"auto\":\n            raise BadRequestError(f\"Invalid tool_choice value: {self.tool_choice}\")\n\n        function_list = []\n        for tool in self.tools:  # pylint: disable=not-an-iterable\n            if tool.type != \"function\":\n                raise BadRequestError(\"Only 'function' tool type is supported\")\n            function_list.append(tool.function.model_dump(by_alias=True))\n\n        conv_template.use_function_calling = True\n        conv_template.function_string = json.dumps(function_list)\n\n\nclass ChatCompletionResponseChoice(BaseModel):\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"tool_calls\", \"error\"]] = None\n    index: int = 0\n    message: ChatCompletionMessage\n    logprobs: Optional[LogProbs] = None\n\n\nclass ChatCompletionStreamResponseChoice(BaseModel):\n    finish_reason: Optional[Literal[\"stop\", \"length\", \"tool_calls\", \"error\"]] = None\n    index: int = 0\n    delta: ChatCompletionMessage\n    logprobs: Optional[LogProbs] = None\n\n\nclass ChatCompletionResponse(BaseModel):\n    \"\"\"OpenAI chat completion response protocol.\n    API reference: https://platform.openai.com/docs/api-reference/chat/object\n    \"\"\"\n\n    id: str\n    choices: List[ChatCompletionResponseChoice]\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: Optional[str] = None\n    system_fingerprint: str\n    object: Literal[\"chat.completion\"] = \"chat.completion\"\n    usage: Optional[CompletionUsage] = None\n\n\nclass ChatCompletionStreamResponse(BaseModel):\n    \"\"\"OpenAI chat completion stream response protocol.\n    API reference: https://platform.openai.com/docs/api-reference/chat/streaming\n    \"\"\"\n\n    id: str\n    choices: List[ChatCompletionStreamResponseChoice]\n    created: int = Field(default_factory=lambda: int(time.time()))\n    model: Optional[str] = None\n    system_fingerprint: str\n    object: Literal[\"chat.completion.chunk\"] = \"chat.completion.chunk\"\n    usage: Optional[CompletionUsage] = None\n\n\n################################################\n\n\ndef openai_api_get_unsupported_fields(\n    request: Union[CompletionRequest, ChatCompletionRequest],\n) -> List[str]:\n    \"\"\"Get the unsupported fields in the request.\"\"\"\n    unsupported_field_default_values: List[Tuple[str, Any]] = [\n        (\"best_of\", 1),\n    ]\n\n    unsupported_fields: List[str] = []\n    for field, value in unsupported_field_default_values:\n        if hasattr(request, field) and getattr(request, field) != value:\n            unsupported_fields.append(field)\n    return unsupported_fields\n"
  },
  {
    "path": "python/mlc_llm/quantization/__init__.py",
    "content": "\"\"\"A subpackage for quantization and dequantization algorithms\"\"\"\n\nfrom .awq_quantization import AWQQuantize\nfrom .block_scale_quantization import BlockScaleQuantize\nfrom .fp8_quantization import FP8PerTensorQuantizeMixtralExperts\nfrom .ft_quantization import FTQuantize\nfrom .group_quantization import GroupQuantize\nfrom .model_quantization import make_awq_quant, make_quantization_functions\nfrom .no_quantization import NoQuantize\nfrom .per_tensor_quantization import PerTensorQuantize\nfrom .quantization import QUANTIZATION, Quantization\n"
  },
  {
    "path": "python/mlc_llm/quantization/awq_quantization.py",
    "content": "\"\"\"AWQ Quantization\"\"\"\n\nfrom dataclasses import dataclass, field\nfrom typing import Any, Callable, Dict, List, Optional\n\nfrom tvm import DataType, DataTypeCode, te, tir, topi\nfrom tvm.relax.frontend import nn\nfrom tvm.runtime import Tensor\n\nfrom mlc_llm.loader import QuantizeMapping\n\nfrom .utils import convert_uint_to_float, is_final_fc, is_moe_gate\n\n\ndef _make_divisible(c, divisor):  # pylint: disable=invalid-name\n    return (c + divisor - 1) // divisor\n\n\ndef _calculate_zeros_width(in_features, group_size=128, pack_num=8):\n    if group_size >= 128:\n        size_multiplier = 1\n    elif group_size == 64:\n        size_multiplier = 2\n    elif group_size == 32:\n        size_multiplier = 4\n    else:\n        raise NotImplementedError\n\n    base_width = _make_divisible(in_features // group_size, pack_num)\n    base_width = _make_divisible(base_width, size_multiplier) * size_multiplier\n    return base_width\n\n\n@dataclass\nclass AWQQuantize:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for AWQ quantization\"\"\"\n\n    name: str\n    kind: str\n    group_size: int\n    quantize_dtype: str  # \"int3\", \"int4\", \"int8\"\n    storage_dtype: str  # \"uint32\"\n    model_dtype: str  # \"float16\", \"float32\"\n\n    num_elem_per_storage: int = 0\n    num_storage_per_group: int = 0\n    max_int_value: int = 0\n\n    prebuilt_quantize_func: Dict[str, Callable[[Tensor], Tensor]] = field(\n        default_factory=lambda: {}\n    )\n\n    def __post_init__(self):\n        assert self.kind == \"awq\"\n        quantize_dtype = DataType(self.quantize_dtype)\n        storage_dtype = DataType(self.storage_dtype)\n        model_dtype = DataType(self.model_dtype)\n        assert quantize_dtype.type_code == DataTypeCode.INT\n        assert storage_dtype.type_code == DataTypeCode.UINT\n        assert model_dtype.type_code == DataTypeCode.FLOAT\n        if storage_dtype.bits < quantize_dtype.bits:\n            
raise ValueError(\"Storage unit must be at least as wide as the quantized element\")\n\n        self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits\n        if self.group_size % self.num_elem_per_storage != 0:\n            raise ValueError(\"Group size must be divisible by the number of elements per storage\")\n        self.num_storage_per_group = self.group_size // self.num_elem_per_storage\n        self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1\n\n    def quantize_model(\n        self,\n        model: nn.Module,\n        quant_map: QuantizeMapping,\n        name_prefix: str,\n    ) -> nn.Module:\n        \"\"\"\n        Quantize the model with AWQ quantization.\n\n        Parameters\n        ----------\n        model : nn.Module\n            The non-quantized nn.Module.\n\n        quant_map : QuantizeMapping\n            The quantize mapping with name mapping and func mapping.\n\n        name_prefix : str\n            The name prefix for visited weights.\n\n        Returns\n        -------\n        ret : nn.Module\n            The quantized nn.Module.\n        \"\"\"\n\n        class _Mutator(nn.Mutator):\n            def __init__(self, config: AWQQuantize, quant_map: QuantizeMapping) -> None:\n                super().__init__()\n                self.config = config\n                self.quant_map = quant_map\n\n            def visit_module(self, name: str, node: nn.Module) -> Any:\n                \"\"\"\n                The visiting method for AWQ quantization of nn.Module nodes.\n\n                Parameters\n                ----------\n                name : str\n                    The name of the current node\n\n                node : nn.Module\n                    The current node of nn.Module to mutate.\n\n                Returns\n                -------\n                ret_node : Any\n                    The new node to replace the current node.\n                \"\"\"\n\n                if (\n                    isinstance(node, 
nn.Linear)\n                    and not is_final_fc(name)\n                    and not is_moe_gate(name, node)\n                ):\n                    return AWQQuantizeLinear.from_linear(node, self.config)\n                return self.visit(name, node)\n\n        model.to(dtype=self.model_dtype)\n        mutator = _Mutator(self, quant_map)\n        model = mutator.visit(name_prefix, model)\n        return model\n\n    def _dequantize(\n        self,\n        weight: te.Tensor,\n        zeros: te.Tensor,\n        scale: te.Tensor,\n        out_shape: Optional[List[tir.PrimExpr]] = None,\n    ):\n        float_weight = convert_uint_to_float(\n            weight,\n            DataType(self.quantize_dtype).bits,\n            self.num_elem_per_storage,\n            self.storage_dtype,\n            self.model_dtype,\n            out_shape=[weight.shape[0], weight.shape[1] * self.num_elem_per_storage],\n            ft_reorder=True,\n        )\n        float_zeros = convert_uint_to_float(\n            zeros,\n            DataType(self.quantize_dtype).bits,\n            self.num_elem_per_storage,\n            self.storage_dtype,\n            self.model_dtype,\n            out_shape=[zeros.shape[0], zeros.shape[1] * self.num_elem_per_storage],\n            ft_reorder=True,\n        )\n        float_weight = topi.transpose(float_weight)\n        float_zeros = topi.transpose(float_zeros)\n        scale = topi.transpose(scale)\n        return te.compute(\n            shape=(\n                [weight.shape[0], weight.shape[1] * self.num_elem_per_storage]\n                if out_shape is None\n                else out_shape\n            ),\n            fcompute=lambda i, j: tir.multiply(\n                tir.subtract(float_weight[i, j], float_zeros[i, j // self.group_size]),\n                scale[i, j // self.group_size],\n            ),\n            name=\"dequantize\",\n        )\n\n\nclass AWQQuantizeLinear(nn.Module):  # pylint: disable=too-many-instance-attributes\n    
\"\"\"An nn.Linear module with AWQ quantization\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        in_features: int,\n        out_features: int,\n        config: AWQQuantize,\n        bias: bool = True,\n        out_dtype: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.config = config\n        self.qweight = nn.Parameter(\n            (in_features, out_features // config.num_elem_per_storage),\n            config.storage_dtype,\n        )\n        self.qzeros = nn.Parameter(\n            (\n                in_features // config.group_size,\n                out_features // config.num_elem_per_storage,\n            ),\n            config.storage_dtype,\n        )\n        self.scales = nn.Parameter(\n            (in_features // config.group_size, out_features), config.model_dtype\n        )\n        if bias:\n            self.bias = nn.Parameter(\n                (out_features,), config.model_dtype if out_dtype is None else out_dtype\n            )\n        else:\n            self.bias = None\n\n    @staticmethod\n    def from_linear(linear: nn.Linear, config: AWQQuantize) -> \"AWQQuantizeLinear\":\n        \"\"\"\n        Converts a non-quantized nn.Linear to an AWQ-quantized AWQQuantizeLinear.\n\n        Parameters\n        ----------\n        linear : nn.Linear\n            The non-quantized nn.Linear.\n\n        config : AWQQuantize\n            The AWQ quantization config.\n\n        Returns\n        -------\n        ret : AWQQuantizeLinear\n            The AWQ-quantized AWQQuantizeLinear layer.\n        \"\"\"\n        return AWQQuantizeLinear(\n            in_features=linear.in_features,\n            out_features=linear.out_features,\n            config=config,\n            bias=getattr(linear, \"bias\", None) is not None,\n            out_dtype=linear.out_dtype,\n        
)\n\n    def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        \"\"\"\n        Forward method for the AWQ-quantized linear layer.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor of the AWQ-quantized linear layer.\n        \"\"\"\n        w = nn.op.tensor_expr_op(  # pylint: disable=invalid-name\n            lambda weight, zeros, scale: self.config._dequantize(  # pylint: disable=protected-access\n                weight,\n                zeros,\n                scale,\n                [\n                    tir.IntImm(\"int64\", self.out_features),\n                    tir.IntImm(\"int64\", self.in_features),\n                ],\n            ),\n            name_hint=\"dequantize\",\n            args=[self.qweight, self.qzeros, self.scales],\n        )\n        w = nn.op.permute_dims(w)  # pylint: disable=invalid-name\n        x = nn.op.matmul(x, w, out_dtype=self.out_dtype)\n        if self.bias is not None:\n            x = x + self.bias\n        return x\n\n    def to(self, dtype: Optional[str] = None) -> None:\n        \"\"\"\n        Override to() such that we do not convert bias if there is an out_dtype.\n        Otherwise, we might run into dtype mismatch when computing x + self.bias.\n        \"\"\"\n        self.qweight.to(dtype=dtype)\n        self.qzeros.to(dtype=dtype)\n        self.scales.to(dtype=dtype)\n        if self.bias is not None and self.out_dtype is None:\n            self.bias.to(dtype=dtype)\n        if dtype is not None and isinstance(getattr(self, \"dtype\", None), str):\n            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init\n"
  },
  {
    "path": "python/mlc_llm/quantization/block_scale_quantization.py",
    "content": "\"\"\"The block-scale quantization config\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Literal, Optional, Tuple\n\nimport tvm\nfrom tvm import DataType, DataTypeCode, te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.script import tir as T\n\nfrom mlc_llm.loader import QuantizeMapping\nfrom mlc_llm.nn import MixtralExperts\nfrom mlc_llm.op import cutlass, extern, moe_matmul, triton\nfrom mlc_llm.support import logging\nfrom mlc_llm.support import tensor_parallel as tp\n\nfrom .utils import apply_sharding, is_final_fc, is_moe_gate\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass BlockScaleQuantize:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for block-scale quantization\"\"\"\n\n    name: str\n    kind: str = \"block-scale-quant\"\n    weight_dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"] = \"float8_e4m3fn\"\n    model_dtype: Literal[\"float16\", \"bfloat16\"] = \"bfloat16\"\n    quantize_linear: bool = True\n    weight_block_size: Optional[Tuple[int, int]] = None\n    use_activation_scale: bool = False\n\n    def __post_init__(self):\n        assert self.kind == \"block-scale-quant\"\n        weight_dtype = DataType(self.weight_dtype)\n        model_dtype = DataType(self.model_dtype)\n        assert weight_dtype.type_code in [\n            DataTypeCode.Float8E4M3FN,\n            DataTypeCode.Float8E5M2,\n        ]\n        assert model_dtype.type_code in [\n            DataTypeCode.FLOAT,\n            DataTypeCode.BFLOAT,\n        ]\n\n    def quantize_model(\n        self,\n        model: nn.Module,\n        quant_map: QuantizeMapping,\n        name_prefix: str,\n    ) -> nn.Module:\n        \"\"\"Quantize model with block-scale quantization.\n\n        Parameters\n        ----------\n        model : nn.Module\n            The non-quantized nn.Module.\n\n        quant_map : QuantizeMapping\n            The quantize mapping with name mapping and func mapping.\n\n        name_prefix : 
str\n            The name prefix for visited weights.\n\n        Returns\n        -------\n        ret : nn.Module\n            The quantized nn.Module.\n        \"\"\"\n\n        weight_block_size = model.weight_block_size\n\n        class _Mutator(nn.Mutator):\n            def __init__(self, config: BlockScaleQuantize, quant_map: QuantizeMapping) -> None:\n                super().__init__()\n                self.config = config\n                self.quant_map = quant_map\n\n            def visit_module(self, name: str, node: nn.Module) -> Any:\n                \"\"\"The visiting method for block-scale quantization of nn.Module nodes.\n\n                Parameters\n                ----------\n                name : str\n                    The name of the current node.\n\n                node : nn.Module\n                    The current node of nn.Module to mutate.\n\n                Returns\n                -------\n                ret : Any\n                    The new node to replace the current node.\n                \"\"\"\n                if getattr(node, \"no_quantization\", False):\n                    return node\n\n                if hasattr(node, \"w_uk\"):\n                    assert hasattr(node, \"w_uv\")\n                    assert node.block_size == weight_block_size\n                    if (\n                        node.qk_nope_head_dim % node.block_size[0] != 0\n                        or node.v_head_dim % node.block_size[1] != 0\n                    ):\n                        raise ValueError(\n                            \"Invalid DeepSeek model config: \"\n                            \"qk_nope_head_dim must be a multiple of weight_block_size[0], and \"\n                            \"v_head_dim must be a multiple of weight_block_size[1]. 
\"\n                            f\"However, qk_nope_head_dim is {node.qk_nope_head_dim}, \"\n                            f\"v_head_dim is {node.v_head_dim}, \"\n                            f\"weight_block_size is {node.block_size}.\"\n                        )\n                    w_uk_shard_strategy = node.w_uk.attrs.get(\"shard_strategy\", None)\n                    w_uv_shard_strategy = node.w_uv.attrs.get(\"shard_strategy\", None)\n                    node.w_uk = nn.Parameter(\n                        (node.num_heads, node.kv_lora_rank, node.qk_nope_head_dim),\n                        self.config.weight_dtype,\n                    )\n                    node.w_uv = nn.Parameter(\n                        (node.num_heads, node.v_head_dim, node.kv_lora_rank),\n                        self.config.weight_dtype,\n                    )\n                    node.w_uk_scale_inv = nn.Parameter(\n                        (\n                            node.num_heads,\n                            node.kv_lora_rank // node.block_size[1],\n                            node.qk_nope_head_dim // node.block_size[0],\n                        ),\n                        \"float32\",\n                    )\n                    node.w_uv_scale_inv = nn.Parameter(\n                        (\n                            node.num_heads,\n                            node.v_head_dim // node.block_size[0],\n                            node.kv_lora_rank // node.block_size[1],\n                        ),\n                        \"float32\",\n                    )\n                    if w_uk_shard_strategy is not None:\n                        assert w_uk_shard_strategy.segs is None\n                        apply_sharding(w_uk_shard_strategy, w_uk_shard_strategy.name, node.w_uk)\n                        apply_sharding(\n                            w_uk_shard_strategy,\n                            f\"{w_uk_shard_strategy.name}_scale_inv\",\n                            node.w_uk_scale_inv,\n   
                     )\n                    if w_uv_shard_strategy is not None:\n                        assert w_uv_shard_strategy.segs is None\n                        apply_sharding(w_uv_shard_strategy, w_uv_shard_strategy.name, node.w_uv)\n                        apply_sharding(\n                            w_uv_shard_strategy,\n                            f\"{w_uv_shard_strategy.name}_scale_inv\",\n                            node.w_uv_scale_inv,\n                        )\n\n                if (\n                    isinstance(node, nn.Linear)\n                    and not is_final_fc(name)\n                    and not is_moe_gate(name, node)\n                ):\n                    if self.config.use_activation_scale:\n                        return BlockScaleQuantizeLinearStaticActivation.from_linear(\n                            node, self.config, weight_block_size\n                        )\n                    return BlockScaleQuantizeLinear.from_linear(\n                        node, self.config, weight_block_size\n                    )\n                if isinstance(node, MixtralExperts):\n                    return BlockScaleQuantizeMixtralExperts.from_mixtral_experts(\n                        node, self.config, weight_block_size\n                    )\n                return self.visit(name, node)\n\n        model.to(dtype=self.model_dtype)\n        mutator = _Mutator(self, quant_map)\n        model = mutator.visit(name_prefix, model)\n        self.weight_block_size = weight_block_size\n        return model\n\n\nclass BlockScaleQuantizeLinear(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Block-scale quantization for Linear\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        in_features: int,\n        out_features: int,\n        weight_dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"],\n        block_size: Tuple[int, int],\n        bias: bool = True,\n        dtype: Optional[str] = None,\n  
      out_dtype: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.weight = nn.Parameter((out_features, in_features), weight_dtype)\n        self.weight_scale_inv = nn.Parameter(\n            (\n                (out_features + block_size[0] - 1) // block_size[0],\n                (in_features + block_size[1] - 1) // block_size[1],\n            ),\n            \"float32\",\n        )\n        self.weight_dtype = weight_dtype\n        self.block_size = block_size\n        if bias:\n            self.bias = nn.Parameter((out_features,), dtype if out_dtype is None else out_dtype)\n        else:\n            self.bias = None\n\n    @staticmethod\n    def from_linear(\n        src: nn.Linear,\n        config: BlockScaleQuantize,\n        weight_block_size: Optional[Tuple[int, int]],\n    ) -> \"BlockScaleQuantizeLinear\":\n        \"\"\"\n        Converts a non-quantized nn.Linear to a block-scale quantized BlockScaleQuantizeLinear\n\n        Parameters\n        ----------\n        src : nn.Linear\n            The non-quantized nn.Linear.\n\n        config : BlockScaleQuantize\n            The block-scale quantization config.\n\n        weight_block_size : Optional[Tuple[int, int]]\n            The weight block size.\n\n        Returns\n        -------\n        ret : BlockScaleQuantizeLinear\n            The block-scale quantized BlockScaleQuantizeLinear.\n        \"\"\"\n        assert weight_block_size is not None\n        out_features, in_features = src.weight.shape\n        quantized_linear = BlockScaleQuantizeLinear(\n            in_features=in_features,\n            out_features=out_features,\n            weight_dtype=config.weight_dtype,\n            block_size=weight_block_size,\n            bias=getattr(src, \"bias\", None) is not None,\n            dtype=config.model_dtype,\n            out_dtype=src.out_dtype,\n     
   )\n        if quantized_linear.bias is not None:\n            quantized_linear.bias.attrs = src.bias.attrs\n        if \"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, shard.name, quantized_linear.weight)\n            if isinstance(shard, tp.ShardSingleDim) and shard.segs is not None:\n                shard.segs = [x // weight_block_size[shard.dim] for x in shard.segs]\n            apply_sharding(shard, f\"{shard.name}_scale_inv\", quantized_linear.weight_scale_inv)\n        return quantized_linear\n\n    def forward(self, x: nn.Tensor) -> nn.Tensor:\n        \"\"\"Forward pass of the block-scale quantized linear layer.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor.\n        \"\"\"\n        m = 1\n        for i in range(x.ndim - 1):\n            m *= x.shape[i]\n        if m == 1:\n            x_shape = x.shape\n            return dequantize_float8_groupwise_scaled_gemv(\n                x.reshape(1, x.shape[-1]),\n                self.weight,\n                self.weight_scale_inv,\n                self.block_size,\n                self.out_dtype if self.out_dtype is not None else x.dtype,\n            ).reshape(*x_shape[:-1], -1)\n\n        shape_supported_by_cutlass = (  # pylint: disable=unused-variable\n            self.weight.shape[0] % 128 == 0 and self.weight.shape[1] % 128 == 0\n        )\n        # Todo: check \"shape supported by cutlass\" for Hopper  # pylint: disable=fixme\n        if (\n            extern.get_store().cutlass_gemm\n            and tvm.get_global_func(\n                \"cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn\", allow_missing=True\n            )\n            is not None\n        ):\n            x_fp8, x_scale = rowwise_group_quant_fp8(\n                x, self.block_size[1], self.weight_dtype, 
transpose_scale=True\n            )\n            x = cutlass.fp8_groupwise_scaled_gemm(\n                x_fp8,\n                x_scale,\n                self.weight,\n                self.weight_scale_inv,\n                self.block_size,\n                self.out_dtype if self.out_dtype is not None else x.dtype,\n            )\n        else:\n            x_fp8, x_scale = rowwise_group_quant_fp8(\n                x, self.block_size[1], self.weight_dtype, transpose_scale=False\n            )\n            x = triton.fp8_groupwise_scaled_gemm(\n                x_fp8,\n                x_scale,\n                self.weight,\n                self.weight_scale_inv,\n                self.block_size,\n                self.out_dtype if self.out_dtype is not None else x.dtype,\n            )\n        if self.bias is not None:\n            x = x + self.bias\n        return x\n\n    def to(self, dtype: Optional[str] = None) -> None:\n        \"\"\"\n        Override to() such that we do not convert bias if there is an out_dtype.\n        Otherwise, we might run into dtype mismatch when computing x + self.bias.\n        \"\"\"\n        if self.bias is not None and self.out_dtype is None:\n            self.bias.to(dtype=dtype)\n        if dtype is not None and isinstance(getattr(self, \"dtype\", None), str):\n            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init\n\n\nclass BlockScaleQuantizeLinearStaticActivation(BlockScaleQuantizeLinear):\n    \"\"\"Block-scale quantization for static activation FP8.\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        in_features: int,\n        out_features: int,\n        weight_dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"],\n        block_size: Tuple[int, int],\n        bias: bool = True,\n        dtype: Optional[str] = None,\n        out_dtype: Optional[str] = None,\n    ) -> None:\n        super().__init__(\n            in_features=in_features,\n            
out_features=out_features,\n            weight_dtype=weight_dtype,\n            block_size=block_size,\n            bias=bias,\n            dtype=dtype,\n            out_dtype=out_dtype,\n        )\n        num_in_groups = (in_features + block_size[1] - 1) // block_size[1]\n        self.activation_scale = nn.Parameter((num_in_groups,), \"float32\")\n\n    @staticmethod\n    def from_linear(\n        src: nn.Linear,\n        config: BlockScaleQuantize,\n        weight_block_size: Optional[Tuple[int, int]],\n    ) -> \"BlockScaleQuantizeLinearStaticActivation\":\n        \"\"\"\n        Convert a non-quantized nn.Linear to a block-scale quantized\n        BlockScaleQuantizeLinearStaticActivation.\n\n        Parameters\n        ----------\n        src : nn.Linear\n            The non-quantized nn.Linear.\n\n        config : BlockScaleQuantize\n            The block-scale quantization config.\n\n        weight_block_size : Optional[Tuple[int, int]]\n            The weight block size.\n\n        Returns\n        -------\n        ret : BlockScaleQuantizeLinearStaticActivation\n            The block-scale quantized BlockScaleQuantizeLinearStaticActivation.\n        \"\"\"\n        assert weight_block_size is not None\n        out_features, in_features = src.weight.shape\n        quantized_linear = BlockScaleQuantizeLinearStaticActivation(\n            in_features=in_features,\n            out_features=out_features,\n            weight_dtype=config.weight_dtype,\n            block_size=weight_block_size,\n            bias=getattr(src, \"bias\", None) is not None,\n            dtype=config.model_dtype,\n            out_dtype=src.out_dtype,\n        )\n        if quantized_linear.bias is not None:\n            quantized_linear.bias.attrs = src.bias.attrs\n        if \"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, shard.name, quantized_linear.weight)\n            if 
isinstance(shard, tp.ShardSingleDim) and shard.segs is not None:\n                shard.segs = [x // weight_block_size[shard.dim] for x in shard.segs]\n            apply_sharding(shard, f\"{shard.name}_scale_inv\", quantized_linear.weight_scale_inv)\n            apply_sharding(\n                shard,\n                f\"{shard.name}_activation_scale\",\n                quantized_linear.activation_scale,\n            )\n        return quantized_linear\n\n    def forward(self, x: nn.Tensor) -> nn.Tensor:\n        x_fp8 = static_activation_group_quant_fp8(\n            x,\n            self.activation_scale,\n            self.block_size[1],\n            self.weight_dtype,\n        )\n        shape_supported_by_cutlass = (\n            self.weight.shape[0] % 128 == 0 and self.weight.shape[1] % 128 == 0\n        )\n        if (\n            extern.get_store().cutlass_gemm\n            and shape_supported_by_cutlass\n            and tvm.get_global_func(\n                \"cutlass.groupwise_scaled_gemm_e4m3fn_e4m3fn\", allow_missing=True\n            )\n            is not None\n        ):\n            x_scale = broadcast_activation_scale(\n                x,\n                self.activation_scale,\n                transpose=True,\n            )\n            out = cutlass.fp8_groupwise_scaled_gemm(\n                x_fp8,\n                x_scale,\n                self.weight,\n                self.weight_scale_inv,\n                self.block_size,\n                self.out_dtype if self.out_dtype is not None else x.dtype,\n            )\n        else:\n            x_scale_triton = broadcast_activation_scale(\n                x,\n                self.activation_scale,\n                transpose=False,\n            )\n            out = triton.fp8_groupwise_scaled_gemm(\n                x_fp8,\n                x_scale_triton,\n                self.weight,\n                self.weight_scale_inv,\n                self.block_size,\n                self.out_dtype if 
self.out_dtype is not None else x.dtype,\n            )\n        if self.bias is not None:\n            out = out + self.bias\n        return out\n\n\nclass BlockScaleQuantizeMixtralExperts(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"Block-scale quantization for MoE experts\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        num_local_experts: int,\n        in_features: int,\n        out_features: int,\n        weight_dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"],\n        block_size: Tuple[int, int],\n    ) -> None:\n        super().__init__()\n        self.num_local_experts = num_local_experts\n        self.in_features = in_features\n        self.out_features = out_features\n        self.weight = nn.Parameter((num_local_experts, out_features, in_features), weight_dtype)\n        self.weight_scale_inv = nn.Parameter(\n            (\n                num_local_experts,\n                (out_features + block_size[0] - 1) // block_size[0],\n                (in_features + block_size[1] - 1) // block_size[1],\n            ),\n            \"float32\",\n        )\n        self.weight_dtype = weight_dtype\n        self.block_size = block_size\n\n    @staticmethod\n    def from_mixtral_experts(\n        src: \"MixtralExperts\",\n        config: BlockScaleQuantize,\n        weight_block_size: Optional[Tuple[int, int]],\n    ) -> \"BlockScaleQuantizeMixtralExperts\":\n        \"\"\"\n        Converts a non-quantized MixtralExperts to a block-scale\n        quantized BlockScaleQuantizeMixtralExperts\n\n        Parameters\n        ----------\n        src : MixtralExperts\n            The non-quantized MixtralExperts\n\n        config : BlockScaleQuantize\n            The block-scale quantization config.\n\n        weight_block_size : Optional[Tuple[int, int]]\n            The weight block size.\n\n        Returns\n        -------\n        ret : BlockScaleQuantizeMixtralExperts\n            The block-scale 
quantized BlockScaleQuantizeMixtralExperts layer.\n        \"\"\"\n        assert weight_block_size is not None\n        quantized_mistral_experts = BlockScaleQuantizeMixtralExperts(\n            num_local_experts=src.num_local_experts,\n            in_features=src.in_features,\n            out_features=src.out_features,\n            weight_dtype=config.weight_dtype,\n            block_size=weight_block_size,\n        )\n        if \"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, shard.name, quantized_mistral_experts.weight)\n            if isinstance(shard, tp.ShardSingleDim) and shard.segs is not None:\n                shard.segs = [x // weight_block_size[shard.dim - 1] for x in shard.segs]\n            apply_sharding(\n                shard,\n                f\"{shard.name}_scale_inv\",\n                quantized_mistral_experts.weight_scale_inv,\n            )\n        return quantized_mistral_experts\n\n    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:\n        \"\"\"Forward pass of the block-scale quantized MixtralExperts.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        indptr : nn.Tensor\n            The indptr tensor of group gemm, with shape of [num_experts + 1,].\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor.\n        \"\"\"\n        if indptr.ndim == 2:\n            # The input is for single token, which does not need group gemm\n            # and can be specialized.\n            expert_indices = indptr\n            assert expert_indices.shape[0] == 1\n            return moe_matmul.dequantize_block_scale_float8_gemv(\n                x,\n                self.weight,\n                self.weight_scale_inv,\n                expert_indices,\n                self.block_size,\n                x.dtype,\n            )\n\n        x_fp8, x_scale = 
rowwise_group_quant_fp8(\n            x, self.block_size[1], self.weight_dtype, transpose_scale=False\n        )\n        if (\n            extern.get_store().cutlass_gemm\n            and tvm.get_global_func(\n                \"cutlass.groupwise_scaled_group_gemm_e4m3fn_e4m3fn\", allow_missing=True\n            )\n            is not None\n        ):\n            x = cutlass.fp8_groupwise_scaled_group_gemm(\n                x_fp8,\n                x_scale,\n                self.weight,\n                self.weight_scale_inv,\n                indptr,\n                self.block_size,\n                x.dtype,\n            )\n        else:\n            x = triton.fp8_groupwise_scaled_group_gemm(\n                x_fp8,\n                x_scale,\n                self.weight,\n                self.weight_scale_inv,\n                indptr,\n                self.block_size,\n                x.dtype,\n            )\n        return x\n\n    def to(self, dtype: Optional[str] = None) -> None:\n        \"\"\"\n        Override to() so that the FP8 weights and their float32 scales are never\n        converted; only the recorded dtype attribute (if present) is updated.\n        \"\"\"\n        if dtype is not None and isinstance(getattr(self, \"dtype\", None), str):\n            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init\n\n\ndef rowwise_group_quant_fp8(  # pylint: disable=too-many-arguments\n    x: nn.Tensor,\n    group_size: int,\n    dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"],\n    transpose_scale: bool,\n    eps: float = 1e-10,\n    keep_first_batch_dim: bool = False,\n) -> Tuple[nn.Tensor, nn.Tensor]:\n    \"\"\"Row-wise group quantization of a tensor to FP8.\n\n    Parameters\n    ----------\n    x : nn.Tensor\n        The input tensor.\n\n    group_size : int\n        The group size per row for quantization.\n\n    dtype : Literal[\"float8_e4m3fn\", \"float8_e5m2\"]\n        The FP8 dtype of the quantized tensor.\n\n    
transpose_scale : bool\n        Whether to return the transposed scales.\n\n    eps : float\n        The lower bound applied to each group's absolute maximum to avoid division by zero.\n\n    keep_first_batch_dim : bool\n        When transpose_scale is True, whether to keep the first (batch) dimension in front.\n\n    Returns\n    -------\n    x_fp8 : 
nn.Tensor\n        The quantized tensor.\n\n    x_scale : nn.Tensor\n        The scales of the quantized tensor.\n        If transpose_scale is True, the shape is\n        (ceildiv(x.shape[-1], group_size), *x.shape[:-1]), or\n        (x.shape[0], ceildiv(x.shape[-1], group_size), *x.shape[1:-1]) when keep_first_batch_dim\n        is True.\n        Otherwise, the shape is (*x.shape[:-1], ceildiv(x.shape[-1], group_size)).\n    \"\"\"\n    assert x.ndim >= 2\n    assert group_size > 0\n\n    def quantize(x: te.Tensor):\n        num_group = tir.ceildiv(x.shape[-1], group_size)\n        max_abs_shape = (*x.shape[:-1], num_group)\n        max_abs_reduce_axis = te.reduce_axis((0, group_size), name=\"r\")\n        scale_dtype = \"float32\"\n        max_abs = te.compute(\n            shape=max_abs_shape,\n            fcompute=lambda *idx: te.max(\n                tir.if_then_else(\n                    idx[-1] * group_size + max_abs_reduce_axis < x.shape[-1],\n                    tir.Max(\n                        te.abs(\n                            x(*idx[:-1], idx[-1] * group_size + max_abs_reduce_axis).astype(\n                                scale_dtype\n                            )\n                        ),\n                        eps,\n                    ),\n                    tir.min_value(scale_dtype),\n                ),\n                axis=max_abs_reduce_axis,\n            ),\n            name=\"max_abs\",\n        )\n        assert dtype in [\"float8_e4m3fn\", \"float8_e5m2\"]\n        fp8_max = 448.0 if dtype == \"float8_e4m3fn\" else 57344.0\n        fp8_min = -fp8_max\n        scale = te.compute(\n            shape=max_abs_shape,\n            fcompute=lambda *idx: max_abs(*idx) / tir.const(fp8_max, scale_dtype),\n            name=\"scale\",\n        )\n        x_quantized = te.compute(\n            shape=x.shape,\n            fcompute=lambda *idx: tir.max(\n                tir.min(\n                    x(*idx).astype(scale_dtype) / scale(*idx[:-1], 
idx[-1] // group_size),\n                    fp8_max,\n                ),\n                fp8_min,\n        
    ).astype(dtype),\n            name=\"x_quantized\",\n        )\n        if transpose_scale:\n            if not keep_first_batch_dim:\n                scale = te.compute(\n                    shape=(num_group, *x.shape[:-1]),\n                    fcompute=lambda *idx: scale(*idx[1:], idx[0]),\n                    name=\"scale\",\n                )\n            else:\n                assert len(x.shape) > 2\n                scale = te.compute(\n                    shape=(x.shape[0], num_group, *x.shape[1:-1]),\n                    fcompute=lambda *idx: scale(idx[0], *idx[2:], idx[1]),\n                    name=\"scale\",\n                )\n        return x_quantized, scale\n\n    x_quantized, scale = nn.tensor_expr_op(quantize, name_hint=\"rowwise_group_quant_fp8\", args=[x])\n    return x_quantized, scale\n\n\ndef static_activation_group_quant_fp8(\n    x: nn.Tensor,\n    activation_scale: nn.Tensor,\n    group_size: int,\n    dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"],\n) -> nn.Tensor:\n    \"\"\"Quantize activations with a pre-computed scale.\"\"\"\n\n    assert activation_scale.ndim == 1\n\n    def quantize(x: te.Tensor, scale: te.Tensor):\n        fp8_max = 448.0 if dtype == \"float8_e4m3fn\" else 57344.0\n        fp8_min = -fp8_max\n\n        def fcompute(*idx):\n            group_idx = tir.indexdiv(idx[-1], group_size)\n            return tir.max(\n                tir.min(\n                    x(*idx).astype(\"float32\") / scale(group_idx),\n                    fp8_max,\n                ),\n                fp8_min,\n            ).astype(dtype)\n\n        return te.compute(shape=x.shape, fcompute=fcompute, name=\"static_activation_group_fp8\")\n\n    quantized = nn.tensor_expr_op(\n        quantize,\n        name_hint=\"static_activation_group_fp8\",\n        args=[x, activation_scale],\n    )\n    return quantized\n\n\ndef broadcast_activation_scale(\n    x: nn.Tensor,\n    activation_scale: nn.Tensor,\n    transpose: bool,\n) -> nn.Tensor:\n    
\"\"\"Broadcast the stored per-group activation scales to shape (*x.shape[:-1], num_groups),\n    swapping the last two axes when transpose is True.\"\"\"\n\n    reshape_shape = (1,) * (x.ndim - 1) + (activation_scale.shape[0],)\n    scale = nn.op.reshape(activation_scale, reshape_shape)\n    scale = nn.op.broadcast_to(scale, (*x.shape[:-1], activation_scale.shape[0]))\n    if transpose:\n        axes = list(range(scale.ndim))\n        axes[-1], axes[-2] = axes[-2], axes[-1]\n        scale = nn.op.permute_dims(scale, axes=axes)\n    return scale\n\n\ndef dequantize_float8_groupwise_scaled_gemv(\n    x: nn.Tensor,\n    w: nn.Tensor,\n    w_scale: nn.Tensor,\n    block_size: Tuple[int, int],\n    out_dtype: str,\n) -> nn.Tensor:\n    \"\"\"GEMV for FP8 groupwise scaled quantization.\n\n    Parameters\n    ----------\n    x : Tensor\n        The input tensor of shape (1, k).\n\n    w : Tensor\n        The quantized weight tensor of shape (n, k).\n\n    w_scale : Tensor\n        The scale tensor of shape\n        (ceildiv(n, block_size[0]), ceildiv(k, block_size[1])).\n\n    block_size : Tuple[int, int]\n        The block size of the weight tensor.\n\n    out_dtype : str\n        The output dtype of the GEMV computation.\n\n    Returns\n    -------\n    ret : Tensor\n        The output tensor of shape (n,).\n    \"\"\"\n    assert x.ndim == 2\n    assert w.ndim == 2\n    assert w_scale.ndim == 2\n    assert x.shape[0] == 1\n    assert x.shape[1] == w.shape[1]\n    _, k = x.shape\n    n, _ = w.shape\n    model_dtype = x.dtype\n    quantize_dtype = w.dtype\n\n    assert (n + block_size[0] - 1) // block_size[0] == w_scale.shape[0]\n    assert (k + block_size[1] - 1) // block_size[1] == w_scale.shape[1]\n\n    def _dequantize(w, s, i, j):\n        return w[i, j].astype(model_dtype) * s[i // block_size[0], j // block_size[1]].astype(\n            model_dtype\n        )\n\n    @T.prim_func(private=True)\n    def _func(\n        x: T.Buffer((1, k), model_dtype),  # type: ignore\n        w: 
T.Buffer((n, k), quantize_dtype),  # type: ignore\n        w_scale: T.Buffer(  # type: ignore\n            (\n                (n + block_size[0] - 1) // block_size[0],\n                (k + 
block_size[1] - 1) // block_size[1],\n            ),\n            \"float32\",\n        ),\n        o: T.Buffer((n,), out_dtype),  # type: ignore\n    ):\n        T.func_attr({\"op_pattern\": 4, \"tir.noalias\": True})  # kOutEWiseFusable\n        y = T.sblock_alloc_buffer((n, k), model_dtype)\n        for i1, i2 in T.grid(n, k):\n            with T.sblock(\"dequantize\"):\n                i, j = T.axis.remap(\"SS\", [i1, i2])\n                y[i, j] = _dequantize(w, w_scale, i, j)\n        for i1, i2 in T.grid(n, k):\n            with T.sblock(\"gemv\"):\n                i, j = T.axis.remap(\"SR\", [i1, i2])\n                with T.init():\n                    o[i] = T.cast(T.float16(0), out_dtype)\n                o[i] += (x[0, j] * y[i, j]).astype(out_dtype)\n\n    return nn.op.tensor_ir_op(\n        _func,\n        \"moe_dequantize_gemv\",\n        args=[x, w, w_scale],\n        out=nn.Tensor.placeholder([n], out_dtype),\n    )\n"
  },
  {
    "path": "python/mlc_llm/quantization/fp8_quantization.py",
    "content": "\"\"\"Quantization techniques for FP8\"\"\"\n\nimport numpy as np\nfrom tvm import relax, runtime\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.nn import MixtralExperts\n\nfrom ..op import cutlass, extern, moe_matmul\nfrom . import per_tensor_quantization as ptq\nfrom .utils import apply_sharding\n\n\nclass FP8PerTensorQuantizeMixtralExperts(\n    ptq.PerTensorQuantizeMixtralExperts\n):  # pylint: disable=too-many-instance-attributes\n    \"\"\"MixtralExperts with per-tensor quantization in FP8.\"\"\"\n\n    def __init__(\n        self,\n        num_local_experts,\n        in_features,\n        out_features,\n        config: ptq.PerTensorQuantize,\n        name: str,\n        tensor_parallel_shards=1,\n    ):  # pylint: disable=too-many-arguments\n        super().__init__(num_local_experts, in_features, out_features, config, name)\n        self.tensor_parallel_shards = tensor_parallel_shards\n\n    @staticmethod\n    def from_mixtral_experts(\n        src: \"MixtralExperts\",\n        config: ptq.PerTensorQuantize,\n        name: str,\n    ) -> \"FP8PerTensorQuantizeMixtralExperts\":\n        \"\"\"\n        Converts a non-quantized MixtralExperts to a per-tensor quantized MixtralExperts.\n\n        Parameters\n        ----------\n        src : MixtralExperts\n            The non-quantized MixtralExperts\n\n        config : PerTensorQuantize\n            The FP8 quantization weight_config.\n\n        name : str\n            The name of the layer.\n\n        Returns\n        -------\n        ret : MixtralExpertsFP8\n            The per-tensor quantized MixtralExperts.\n        \"\"\"\n        quantized_mistral_experts = FP8PerTensorQuantizeMixtralExperts(\n            num_local_experts=src.num_local_experts,\n            in_features=src.in_features,\n            out_features=src.out_features,\n            config=config,\n            name=name,\n            tensor_parallel_shards=src.tensor_parallel_shards,\n        )\n\n        if 
\"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, f\"{shard.name}_q_weight\", quantized_mistral_experts.q_weight)\n            # scale doesn't need to be sharded since it's the same for all shards\n\n        return quantized_mistral_experts\n\n    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        w = self.q_weight\n\n        if self.config.calibration_mode == \"max\":\n            _, x_scale = self.config.quantize_float8(  # type: ignore\n                x,\n                quantize_dtype=self.config.activation_dtype,\n                storage_dtype=self.config.activation_dtype,\n            )\n            if self.config.tensor_parallel_shards > 1:\n                x_scale = nn.ccl_allreduce(x_scale, \"max\")\n            x_scale = nn.extern(\n                \"mlc_llm.calibration_observer\",\n                [f\"{self.name}.q_calibration_scale\", \"max\", x_scale],\n                out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype),\n            )\n            x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype)\n            x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype)\n\n        if indptr.ndim == 2:\n            assert indptr.shape[0] == 1\n            return moe_matmul.dequantize_float8_gemv(\n                x, w, self.q_scale, indptr, self.config.weight_dtype\n            )\n\n        if extern.get_store().cutlass_group_gemm:\n            if self.config.calibration_mode == \"inference\":\n                if self.q_calibration_scale is not None:\n                    x /= self.q_calibration_scale.astype(x.dtype)\n                x_q = nn.op.astype(x, dtype=self.config.activation_dtype)\n                x_scale = self.q_calibration_scale\n\n            scale = (\n                x_scale * self.q_scale\n                if self.q_scale is not None\n                else 
nn.wrap_nested(\n                    relax.Constant(runtime.tensor(np.array([1.0]).astype(\"float32\"))),\n                    \"scale\",\n                )\n            )\n            return cutlass.group_gemm(\n                x_q, w, indptr, scale, self.config.weight_dtype, self.config.model_dtype\n            )\n        # Note: convert_weight is target agnostic, so a fallback must be provided\n        w = nn.tensor_expr_op(\n            self.config.dequantize_float8,\n            \"dequantize\",\n            args=[w, self.q_scale, self.config.weight_dtype],\n        )\n        return moe_matmul.group_gemm(x, w, indptr)\n\n\n# pylint: disable=protected-access\nptq.PerTensorQuantizeMixtralExperts._IMPL[\"fp8\"] = FP8PerTensorQuantizeMixtralExperts\n"
  },
  {
    "path": "python/mlc_llm/quantization/ft_quantization.py",
    "content": "\"\"\"The FasterTransformer quantization config\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, List, Literal, Optional, Tuple\n\nimport tvm\nfrom tvm import DataType, DataTypeCode, IRModule, relax, te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.runtime import Tensor\nfrom tvm.s_tir import dlight as dl\nfrom tvm.target import Target\n\nfrom ..loader import QuantizeMapping\nfrom ..op import faster_transformer_dequantize_gemm\nfrom ..support import logging\nfrom ..support.auto_target import detect_cuda_arch_list\nfrom ..support.style import bold\nfrom .group_quantization import (\n    GroupQuantize,\n    GroupQuantizeEmbedding,\n    GroupQuantizeLinear,\n)\nfrom .utils import is_final_fc, is_moe_gate\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass FTQuantize:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for FasterTransformer quantization\"\"\"\n\n    name: str\n    kind: str\n    quantize_dtype: Literal[\"int4\", \"int8\"]\n    storage_dtype: Literal[\"int8\"]\n    model_dtype: Literal[\"float16\"]\n    group_size: Optional[int] = None\n\n    num_elem_per_storage: int = 0\n    max_int_value: int = 0\n\n    def fallback_group_quantize(self) -> GroupQuantize:\n        \"\"\"\n        The fallback group quantization config for other parameters.\n\n        Returns\n        ------\n        quantize: GroupQuantize\n            The group quantization config to fallback.\n        \"\"\"\n        return GroupQuantize(\n            name=self.name,\n            kind=\"group-quant\",\n            group_size=32,  # hardcoded to 32 as only supporting int4 quantization\n            quantize_dtype=self.quantize_dtype,\n            storage_dtype=\"uint32\",\n            model_dtype=self.model_dtype,\n            linear_weight_layout=\"NK\",\n        )\n\n    def __post_init__(self):\n        assert self.kind == \"ft-quant\"\n        quantize_dtype = DataType(self.quantize_dtype)\n        
storage_dtype = DataType(self.storage_dtype)\n        assert self.quantize_dtype in [\"int4\", \"int8\"]\n        assert storage_dtype.type_code == DataTypeCode.INT\n        assert self.model_dtype == \"float16\"\n        assert self.group_size in [None, 64, 128]\n        if storage_dtype.bits < quantize_dtype.bits:\n            raise ValueError(\"Storage unit should be greater than or equal to quantized element\")\n\n        self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits\n        self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1\n        self._quantize_func_cache = {}\n\n    def quantize_model(\n        self,\n        model: nn.Module,\n        quant_map: QuantizeMapping,\n        name_prefix: str,\n    ) -> nn.Module:\n        \"\"\"\n        Quantize model with FasterTransformer quantization\n\n        Parameters\n        ----------\n        model : nn.Module\n            The non-quantized nn.Module.\n\n        quant_map : QuantizeMapping\n            The quantize mapping with name mapping and func mapping.\n\n        name_prefix : str\n            The name prefix for visited weight.\n\n        Returns\n        -------\n        ret : nn.Module\n            The quantized nn.Module.\n        \"\"\"\n\n        class _Mutator(nn.Mutator):\n            def __init__(self, config: FTQuantize, quant_map: QuantizeMapping) -> None:\n                super().__init__()\n                self.config = config\n                self.quant_map = quant_map\n\n            def visit_module(self, name: str, node: nn.Module) -> Any:\n                \"\"\"\n                The visiting method for FasterTransformer quantization of nn.Module nodes.\n\n                Parameters\n                ----------\n                name : str\n                    The name of the current node.\n\n                node : nn.Module\n                    The current node of nn.Module to mutate.\n\n                Returns\n                -------\n                ret_node: Any\n 
                   The new node to replace current node.\n                \"\"\"\n                if isinstance(node, nn.Linear):\n                    weight_name = f\"{name}.weight\"\n                    self.quant_map.param_map[weight_name] = [\n                        f\"{name}.q_weight\",\n                        f\"{name}.q_scale\",\n                    ]\n                    if (\n                        # pylint: disable=too-many-boolean-expressions\n                        is_final_fc(name)\n                        or node.out_dtype == \"float32\"\n                        or (self.config.quantize_dtype == \"int4\" and node.out_features % 8 != 0)\n                        or (self.config.quantize_dtype == \"int8\" and node.out_features % 4 != 0)\n                    ):\n                        # Under any of the conditions we fall back to GroupQuantize\n                        # For `is_final_fc()` see https://github.com/mlc-ai/mlc-llm/issues/1723\n                        # If simply skipping lm_head quantization degrades performance\n                        # Other requirements are from CUTLASS\n                        logger.info(\n                            'Fallback to GroupQuantize for nn.Linear: \"%s\", '\n                            + \"weight.shape: %s, out_dtype: %s\",\n                            bold(name),\n                            node.weight.shape,\n                            node.out_dtype,\n                        )\n                        group_quantize = self.config.fallback_group_quantize()\n                        self.quant_map.map_func[weight_name] = group_quantize.quantize_weight\n                        return GroupQuantizeLinear.from_linear(node, group_quantize)\n                    if not is_moe_gate(name, node):\n                        self.quant_map.map_func[weight_name] = self.config.quantize_weight\n                        return FTQuantizeLinear.from_linear(node, self.config)\n                if isinstance(node, 
nn.Embedding):\n                    weight_name = f\"{name}.weight\"\n                    self.quant_map.param_map[weight_name] = [\n                        f\"{name}.q_weight\",\n                        f\"{name}.q_scale\",\n                    ]\n                    group_quantize = self.config.fallback_group_quantize()\n                    self.quant_map.map_func[weight_name] = group_quantize.quantize_weight\n                    return GroupQuantizeEmbedding.from_embedding(node, group_quantize)\n                return self.visit(name, node)\n\n        model.to(dtype=self.model_dtype)\n        mutator = _Mutator(self, quant_map)\n        model = mutator.visit(name_prefix, model)\n        return model\n\n    def quantize_weight(self, weight: Tensor) -> List[Tensor]:\n        \"\"\"\n        Quantize weight with FasterTransformer quantization\n\n        Parameters\n        ----------\n        weight : Tensor\n            The original weight.\n\n        Returns\n        -------\n        ret: List[Tensor]\n            The list of FasterTransformer quantized weights.\n        \"\"\"\n        assert tvm.get_global_func(\"relax.ext.cutlass\", True), (\n            \"Cutlass should be enabled in TVM runtime to quantize weight, \"\n            \"but it is not enabled in the current TVM runtime environment. 
\"\n            \"To enable Cutlass in TVM runtime, please `set(USE_CUTLASS ON)` \"\n            \"in config.cmake when compiling TVM from source\"\n        )\n        assert len(weight.shape) == 2\n        device = weight.device\n        device_type = device._DEVICE_TYPE_TO_NAME[  # pylint: disable=protected-access\n            device.dlpack_device_type()\n        ]\n        if device_type == \"cuda\":\n            target = Target.current()\n            if target is None:\n                target = Target.from_device(device)\n            with target:\n\n                def _create_quantize_func() -> IRModule:\n                    bb = relax.BlockBuilder()  # pylint: disable=invalid-name\n                    weight_var = relax.Var(\n                        \"weight\", relax.TensorStructInfo(weight.shape, weight.dtype)\n                    )\n                    with bb.function(name=\"main\", params=[weight_var]):\n                        with bb.dataflow():\n                            lv0 = bb.emit_te(\n                                self._quantize, weight_var\n                            )  # pylint: disable=invalid-name\n                            lv1 = bb.normalize(lv0[0])\n                            lv2 = bb.emit(\n                                relax.call_pure_packed(\n                                    \"cutlass.ft_preprocess_weight\",\n                                    lv1,\n                                    detect_cuda_arch_list(target=target)[0],\n                                    DataType(self.quantize_dtype).bits == 4,\n                                    sinfo_args=lv1.struct_info,\n                                )\n                            )\n                            gv = bb.emit_output(\n                                relax.Tuple([lv2, lv0[1]])\n                            )  # pylint: disable=invalid-name\n                        bb.emit_func_output(gv)\n                    return bb.finalize()\n\n                def 
_compile_quantize_func(mod: IRModule) -> Callable:\n                    mod = dl.ApplyDefaultSchedule(  # type: ignore   # pylint: disable=not-callable\n                        dl.gpu.Reduction(),\n                        dl.gpu.GeneralReduction(),\n                        dl.gpu.Fallback(),\n                    )(mod)\n                    ex = relax.build(mod, target=target)\n                    vm = relax.VirtualMachine(ex, device)  # pylint: disable=invalid-name\n                    return vm[\"main\"]\n\n                key = str(\n                    (\n                        int(weight.shape[0]),\n                        int(weight.shape[1]),\n                        weight.dtype,\n                        device_type,\n                    )\n                )\n                quantize_func = self._quantize_func_cache.get(key, None)\n                if quantize_func is None:\n                    logger.info(\"Compiling quantize function for key: %s\", key)\n                    quantize_func = _compile_quantize_func(_create_quantize_func())\n                    self._quantize_func_cache[key] = quantize_func\n                data = quantize_func(weight)\n                return data\n        else:\n            raise NotImplementedError(f\"Device type {device_type} is not supported\")\n\n    def _quantize(  # pylint: disable=too-many-locals\n        self,\n        weight: te.Tensor,\n    ) -> Tuple[te.Tensor, te.Tensor]:\n        \"\"\"FasterTransformer quantization for weight tensor, defined in tensor expression.\"\"\"\n        assert len(weight.shape) == 2\n        n, k = weight.shape\n\n        cur_group_size = k if not self.group_size else self.group_size\n        scale_shape = (tir.ceildiv(k, cur_group_size), n)\n        r = te.reduce_axis((0, cur_group_size), name=\"r\")\n\n        max_abs = te.compute(\n            shape=scale_shape,\n            fcompute=lambda j, i: te.max(\n                tir.if_then_else(\n                    j * cur_group_size + r < 
k,\n                    te.abs(weight[i, j * cur_group_size + r]),\n                    te.min_value(self.model_dtype),\n                ),\n                axis=r,\n            ),\n            name=\"max_abs_value\",\n        )\n        max_int = tir.const(self.max_int_value, self.model_dtype)\n        scale = te.compute(\n            scale_shape,\n            lambda i, j: max_abs[i, j].astype(self.model_dtype) / max_int,\n            name=\"scale\",\n        )\n        # compute scaled weight\n        quantize_dtype = DataType(self.quantize_dtype)\n        bin_mask = tir.const((1 << quantize_dtype.bits) - 1, self.storage_dtype)\n        scaled_weight = te.compute(\n            shape=weight.shape,\n            fcompute=lambda i, j: (\n                tir.min(\n                    tir.max(\n                        tir.round(weight[i, j] / scale[j // cur_group_size, i]),\n                        -max_int - 1,\n                    ),\n                    max_int,\n                ).astype(self.storage_dtype)\n                & bin_mask\n            ),\n        )\n\n        quantized_weight_shape = (k, tir.ceildiv(n, self.num_elem_per_storage))\n        r = te.reduce_axis((0, self.num_elem_per_storage), name=\"r\")  # pylint: disable=invalid-name\n        quantized_weight = te.compute(\n            shape=quantized_weight_shape,\n            fcompute=lambda j, i: tir.sum(\n                tir.if_then_else(\n                    i * self.num_elem_per_storage + r < n,\n                    scaled_weight[i * self.num_elem_per_storage + r, j]\n                    << (\n                        r.astype(self.storage_dtype)\n                        * tir.const(quantize_dtype.bits, self.storage_dtype)\n                    ),\n                    tir.const(0, self.storage_dtype),\n                ),\n                axis=r,\n            ),\n            name=\"weight\",\n        )\n\n        return quantized_weight, scale\n\n\nclass FTQuantizeLinear(nn.Module):  # pylint: 
disable=too-many-instance-attributes\n    \"\"\"An nn.Linear module with FasterTransformer quantization\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        in_features: int,\n        out_features: int,\n        config: FTQuantize,\n        bias: bool = True,\n        out_dtype: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.config = config\n        cur_group_size = in_features if not config.group_size else config.group_size\n        self.q_weight = nn.Parameter(\n            (in_features, tir.ceildiv(out_features, config.num_elem_per_storage)),\n            config.storage_dtype,\n        )\n        self.q_scale = nn.Parameter(\n            (tir.ceildiv(in_features, cur_group_size), out_features), config.model_dtype\n        )\n        if bias:\n            self.bias = nn.Parameter(\n                (out_features,), config.model_dtype if out_dtype is None else out_dtype\n            )\n        else:\n            self.bias = None\n\n    @staticmethod\n    def from_linear(src: nn.Linear, config: FTQuantize) -> \"FTQuantizeLinear\":\n        \"\"\"\n        Converts a non-quantized nn.Linear to a FasterTransformer quantized FTQuantizeLinear\n\n        Parameters\n        ----------\n        src : nn.Linear\n            The non-quantized nn.Linear.\n\n        config : FTQuantize\n            The FasterTransformer quantization config.\n\n        Returns\n        -------\n        ret : FTQuantizeLinear\n            The FasterTransformer quantized FTQuantizeLinear layer.\n        \"\"\"\n        quantized_linear = FTQuantizeLinear(\n            in_features=src.in_features,\n            out_features=src.out_features,\n            config=config,\n            bias=getattr(src, \"bias\", None) is not None,\n            out_dtype=src.out_dtype,\n        )\n        if quantized_linear.bias is 
not None:\n            quantized_linear.bias.attrs = src.bias.attrs\n        return quantized_linear\n\n    def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        \"\"\"\n        Forward method for FasterTransformer quantized linear layer.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the FasterTransformer quantized linear layer.\n        \"\"\"\n        return faster_transformer_dequantize_gemm(\n            x, self.q_weight, self.q_scale, self.bias, group_size=self.config.group_size\n        )\n\n    def to(self, dtype: Optional[str] = None) -> None:\n        \"\"\"\n        Override to() such that we do not convert bias if there is an out_dtype.\n        Otherwise, we might run into dtype mismatch when computing x + self.bias.\n        \"\"\"\n        self.q_weight.to(dtype=dtype)\n        self.q_scale.to(dtype=dtype)\n        if self.bias is not None and self.out_dtype is None:\n            self.bias.to(dtype=dtype)\n        if dtype is not None and isinstance(getattr(self, \"dtype\", None), str):\n            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init\n"
  },
  {
    "path": "python/mlc_llm/quantization/group_quantization.py",
    "content": "\"\"\"The group quantization config\"\"\"\n\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Any, List, Literal, Optional, Tuple, Union\n\nfrom tvm import DataType, DataTypeCode, IRModule, relax, te, tir, topi\nfrom tvm.relax.frontend import nn\nfrom tvm.runtime import Tensor\n\nfrom mlc_llm.loader import QuantizeMapping\nfrom mlc_llm.nn import MixtralExperts\nfrom mlc_llm.support import logging\n\nfrom .utils import (\n    apply_sharding,\n    compile_quantize_func,\n    convert_uint_to_float,\n    is_final_fc,\n    is_moe_gate,\n    pack_weight,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass GroupQuantize:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for group quantization\"\"\"\n\n    name: str\n    kind: str\n    group_size: int\n    quantize_dtype: Literal[\"int3\", \"int4\", \"int8\"]\n    storage_dtype: Literal[\"uint32\"]\n    model_dtype: Literal[\"float16\", \"float32\", \"bfloat16\"]\n    linear_weight_layout: Literal[\"KN\", \"NK\"]\n    quantize_embedding: bool = True\n    quantize_final_fc: bool = True\n\n    num_elem_per_storage: int = 0\n    num_storage_per_group: int = 0\n    max_int_value: int = 0\n    tensor_parallel_shards: int = 0\n\n    def __post_init__(self):\n        assert self.kind == \"group-quant\"\n        quantize_dtype = DataType(self.quantize_dtype)\n        storage_dtype = DataType(self.storage_dtype)\n        model_dtype = DataType(self.model_dtype)\n        assert quantize_dtype.type_code == DataTypeCode.INT\n        assert storage_dtype.type_code == DataTypeCode.UINT\n        assert model_dtype.type_code in (DataTypeCode.FLOAT, DataTypeCode.BFLOAT)\n        if storage_dtype.bits < quantize_dtype.bits:\n            raise ValueError(\"Storage unit should be greater or equal to quantized element\")\n\n        self.num_elem_per_storage = storage_dtype.bits // quantize_dtype.bits\n        if self.group_size % self.num_elem_per_storage 
!= 0:\n            raise ValueError(\"Group size should be divisible by the number of elements per storage\")\n        self.num_storage_per_group = self.group_size // self.num_elem_per_storage\n        self.max_int_value = (2 ** (quantize_dtype.bits - 1)) - 1\n        self.linear_quant_axis = 0 if self.linear_weight_layout == \"KN\" else 1\n        self._quantize_func_cache = {}\n\n    def quantize_model(\n        self,\n        model: nn.Module,\n        quant_map: QuantizeMapping,\n        name_prefix: str,\n    ) -> nn.Module:\n        \"\"\"\n        Quantize model with group quantization\n\n        Parameters\n        ----------\n        model : nn.Module\n            The non-quantized nn.Module.\n\n        quant_map : QuantizeMapping\n            The quantize mapping with name mapping and func mapping.\n\n        name_prefix : str\n            The name prefix for visited weight.\n\n        Returns\n        -------\n        ret : nn.Module\n            The quantized nn.Module.\n        \"\"\"\n\n        class _Mutator(nn.Mutator):\n            def __init__(self, config: GroupQuantize, quant_map: QuantizeMapping) -> None:\n                super().__init__()\n                self.config = config\n                self.quant_map = quant_map\n\n            def visit_module(self, name: str, node: nn.Module) -> Any:\n                \"\"\"\n                The visiting method for group quantization of nn.Module nodes.\n\n                Parameters\n                ----------\n                name : str\n                    The name of the current node.\n\n                node : nn.Module\n                    The current node of nn.Module to mutate.\n\n                Returns\n                -------\n                ret_node: Any\n                    The new node to replace current node.\n                \"\"\"\n                if getattr(node, \"no_quantization\", False):\n                    return node\n\n                if (\n                    isinstance(node, 
nn.Linear)\n                    and (not is_final_fc(name) or self.config.quantize_final_fc)\n                    and not is_moe_gate(name, node)\n                ):\n                    weight_name = f\"{name}.weight\"\n                    self.quant_map.param_map[weight_name] = [\n                        f\"{name}.q_weight\",\n                        f\"{name}.q_scale\",\n                    ]\n                    self.quant_map.map_func[weight_name] = partial(\n                        self.config.quantize_weight,\n                        output_transpose=self.config.linear_weight_layout == \"KN\",\n                    )\n                    return GroupQuantizeLinear.from_linear(node, self.config)\n                if isinstance(node, nn.Embedding) and self.config.quantize_embedding:\n                    weight_name = f\"{name}.weight\"\n                    self.quant_map.param_map[weight_name] = [\n                        f\"{name}.q_weight\",\n                        f\"{name}.q_scale\",\n                    ]\n                    self.quant_map.map_func[weight_name] = self.config.quantize_weight\n                    return GroupQuantizeEmbedding.from_embedding(node, self.config)\n                if isinstance(node, MixtralExperts):\n                    weight_name = f\"{name}.weight\"\n                    self.quant_map.param_map[weight_name] = [\n                        f\"{name}.q_weight\",\n                        f\"{name}.q_scale\",\n                    ]\n                    self.quant_map.map_func[weight_name] = self.config.quantize_weight\n                    return GroupQuantizeMixtralExperts.from_mixtral_experts(node, self.config)\n                return self.visit(name, node)\n\n        model.to(dtype=self.model_dtype)\n        mutator = _Mutator(self, quant_map)\n        model = mutator.visit(name_prefix, model)\n        return model\n\n    def _dequantize(\n        self,\n        weight: te.Tensor,\n        scale: te.Tensor,\n        axis: int,\n  
      out_shape: Optional[List[tir.PrimExpr]] = None,\n    ):\n        tir_max_int = tir.const(self.max_int_value, self.model_dtype)\n        float_weight = convert_uint_to_float(\n            weight,\n            DataType(self.quantize_dtype).bits,\n            self.num_elem_per_storage,\n            self.storage_dtype,\n            self.model_dtype,\n            axis=axis,\n            out_shape=out_shape,\n        )\n        if out_shape is None:\n            out_shape = weight.shape\n            out_shape[axis] *= self.num_elem_per_storage\n        axis = axis if axis >= 0 else len(out_shape) + axis\n        return te.compute(\n            shape=out_shape,\n            fcompute=lambda *idx: tir.multiply(\n                tir.subtract(\n                    float_weight(*idx),\n                    tir_max_int,\n                ),\n                scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :]),\n            ),\n            name=\"dequantize\",\n        )\n\n    def quantize_weight(\n        self, weight: Tensor, axis: int = -1, output_transpose: bool = False\n    ) -> List[Tensor]:\n        \"\"\"\n        Quantize weight with group quantization\n\n        Parameters\n        ----------\n        weight : Tensor\n            The original weight.\n\n        axis : int\n            The group axis.\n\n        output_transpose : bool\n            Whether to transpose the output quantized weight. 
Only 2D weight is supported.\n\n        Returns\n        -------\n        ret: List[Tensor]\n            The list of group quantized weights.\n        \"\"\"\n        device = weight.device\n        device_type = device._DEVICE_TYPE_TO_NAME[  # pylint: disable=protected-access\n            device.dlpack_device_type()\n        ]\n        axis = axis if axis >= 0 else len(weight.shape) + axis\n\n        def _create_quantize_func() -> IRModule:\n            bb = relax.BlockBuilder()  # pylint: disable=invalid-name\n            weight_var = relax.Var(\"weight\", relax.TensorStructInfo(weight.shape, weight.dtype))\n            with bb.function(name=\"main\", params=[weight_var]):\n                with bb.dataflow():\n                    lv = bb.emit_te(self._quantize, weight_var, axis, output_transpose)\n                    gv = bb.emit_output(lv)  # pylint: disable=invalid-name\n                bb.emit_func_output(gv)\n            return bb.finalize()\n\n        key = (\n            f\"({weight.shape}, {weight.dtype}, {device_type}, \"\n            f\"axis={axis}, output_transpose={output_transpose})\"\n        )\n        quantize_func = self._quantize_func_cache.get(key, None)\n        if quantize_func is None:\n            logger.info(\"Compiling quantize function for key: %s\", key)\n            quantize_func = compile_quantize_func(_create_quantize_func(), device=device)\n            self._quantize_func_cache[key] = quantize_func\n        return quantize_func(weight)\n\n    def _quantize(  # pylint: disable=too-many-locals\n        self,\n        weight: te.Tensor,\n        axis: int = -1,\n        output_transpose: bool = False,\n    ) -> Tuple[te.Tensor, te.Tensor]:\n        \"\"\"Group quantization for weight tensor, defined in tensor expression.\"\"\"\n        max_int = tir.const(self.max_int_value, self.model_dtype)\n        shape = weight.shape  # pylint: disable=invalid-name\n        axis = axis if axis >= 0 else len(shape) + axis\n        k = shape[axis]\n   
     # compute scale per group\n        r = te.reduce_axis((0, self.group_size), name=\"r\")  # pylint: disable=invalid-name\n        num_group = tir.ceildiv(k, self.group_size)\n        scale_shape = (*shape[:axis], num_group, *shape[axis + 1 :])\n        max_abs = te.compute(\n            shape=scale_shape,\n            fcompute=lambda *idx: te.max(\n                tir.if_then_else(\n                    idx[axis] * self.group_size + r < k,\n                    te.abs(\n                        weight(\n                            *idx[:axis],\n                            idx[axis] * self.group_size + r,\n                            *idx[axis + 1 :],\n                        )\n                    ),\n                    te.min_value(self.model_dtype),\n                ),\n                axis=r,\n            ),\n            name=\"max_abs_value\",\n        )\n        scale = te.compute(\n            scale_shape,\n            lambda *idx: max_abs(*idx).astype(self.model_dtype) / max_int,\n            name=\"scale\",\n        )\n        # compute scaled weight\n        scaled_weight = te.compute(\n            shape=weight.shape,\n            fcompute=lambda *idx: tir.min(\n                tir.max(\n                    tir.round(\n                        weight(*idx)\n                        / scale(*idx[:axis], idx[axis] // self.group_size, *idx[axis + 1 :])\n                        + max_int\n                    ),\n                    tir.const(0, self.model_dtype),\n                ),\n                max_int * 2,\n            ).astype(self.storage_dtype),\n        )\n        # compute quantized weight per storage\n        num_storage = self.num_storage_per_group * num_group\n        quantized_weight_shape = (*shape[:axis], num_storage, *shape[axis + 1 :])\n        quantized_weight = pack_weight(\n            scaled_weight,\n            axis=axis,\n            num_elem_per_storage=self.num_elem_per_storage,\n            weight_dtype=self.quantize_dtype,\n        
    storage_dtype=self.storage_dtype,\n            out_shape=quantized_weight_shape,\n        )\n        if output_transpose:\n            if len(quantized_weight.shape) != 2 or len(scale.shape) != 2:\n                raise ValueError(\n                    \"Does not support transpose output quantized weight with ndim != 2\"\n                )\n            quantized_weight = topi.transpose(quantized_weight)\n            scale = topi.transpose(scale)\n        return quantized_weight, scale\n\n\nclass GroupQuantizeLinear(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"An nn.Linear module with group quantization\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        in_features: int,\n        out_features: Union[int, tir.Var],\n        config: GroupQuantize,\n        bias: bool = True,\n        out_dtype: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype\n        self.config = config\n        num_group = tir.ceildiv(in_features, config.group_size)\n        num_shards = config.tensor_parallel_shards\n        if num_shards > 1 and (in_features * num_shards // config.group_size) % num_shards != 0:\n            raise ValueError(\n                f\"The linear dimension {in_features * num_shards} has \"\n                f\"{in_features * num_shards // config.group_size} groups under group size \"\n                f\"{config.group_size}. 
The groups cannot be evenly distributed on \"\n                f\"{num_shards} GPUs.\\n\"\n                \"Possible solutions: reduce number of GPUs, or use quantization with smaller \"\n                \"group size.\"\n            )\n        if config.linear_weight_layout == \"KN\":\n            self.q_weight = nn.Parameter(\n                (config.num_storage_per_group * num_group, out_features),\n                config.storage_dtype,\n            )\n            self.q_scale = nn.Parameter((num_group, out_features), config.model_dtype)\n        else:\n            self.q_weight = nn.Parameter(\n                (out_features, config.num_storage_per_group * num_group),\n                config.storage_dtype,\n            )\n            self.q_scale = nn.Parameter((out_features, num_group), config.model_dtype)\n        if bias:\n            self.bias = nn.Parameter(\n                (out_features,), config.model_dtype if out_dtype is None else out_dtype\n            )\n        else:\n            self.bias = None\n\n    @staticmethod\n    def from_linear(src: nn.Linear, config: GroupQuantize) -> \"GroupQuantizeLinear\":\n        \"\"\"\n        Converts a non-quantized nn.Linear to a group quantized GroupQuantizeLinear\n\n        Parameters\n        ----------\n        src : nn.Linear\n            The non-quantized nn.Linear.\n\n        config : GroupQuantize\n            The group quantization config.\n\n        Returns\n        -------\n        ret : GroupQuantizeLinear\n            The group quantized GroupQuantizeLinear layer.\n        \"\"\"\n        # For dynamic shape, src.out_features is `\"name\"`; src.weight.shape[0] is `tir.Var(\"name\")`\n        out_features, in_features = src.weight.shape\n        quantized_linear = GroupQuantizeLinear(\n            in_features=in_features,\n            out_features=out_features,\n            config=config,\n            bias=getattr(src, \"bias\", None) is not None,\n            out_dtype=src.out_dtype,\n        )\n    
    if quantized_linear.bias is not None:\n            quantized_linear.bias.attrs = src.bias.attrs\n        if \"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, f\"{shard.name}_q_weight\", quantized_linear.q_weight)\n            apply_sharding(shard, f\"{shard.name}_q_scale\", quantized_linear.q_scale)\n        return quantized_linear\n\n    def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        \"\"\"\n        Forward method for group quantized linear layer.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the group quantized linear layer.\n        \"\"\"\n        w = nn.op.tensor_expr_op(  # pylint: disable=invalid-name\n            lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access\n                weight,\n                scale,\n                axis=self.config.linear_quant_axis,\n                out_shape=(\n                    [\n                        (\n                            tir.IntImm(\"int64\", self.out_features)\n                            if isinstance(self.out_features, int)\n                            else weight.shape[0]\n                        ),  # Reuse same tir.Var for symbolic shape (after Exporter)\n                        tir.IntImm(\"int64\", self.in_features),\n                    ]\n                    if self.config.linear_weight_layout == \"NK\"\n                    else [\n                        tir.IntImm(\"int64\", self.in_features),\n                        (\n                            tir.IntImm(\"int64\", self.out_features)\n                            if isinstance(self.out_features, int)\n                            else weight.shape[1]\n                        ),  # Reuse same tir.Var for symbolic shape (after Exporter)\n          
          ]\n                ),\n            ),\n            name_hint=\"dequantize\",\n            args=[self.q_weight, self.q_scale],\n        )\n        if self.config.linear_weight_layout == \"NK\":\n            w = nn.op.permute_dims(w)  # pylint: disable=invalid-name\n        x = nn.op.matmul(x, w, out_dtype=self.out_dtype)\n        if self.bias is not None:\n            x = x + self.bias\n        return x\n\n    def to(self, dtype: Optional[str] = None) -> None:\n        \"\"\"\n        Override to() such that we do not convert bias if there is an out_dtype.\n        Otherwise, we might run into dtype mismatch when computing x + self.bias.\n        \"\"\"\n        self.q_weight.to(dtype=dtype)\n        self.q_scale.to(dtype=dtype)\n        if self.bias is not None and self.out_dtype is None:\n            self.bias.to(dtype=dtype)\n        if dtype is not None and isinstance(getattr(self, \"dtype\", None), str):\n            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init\n\n\nclass GroupQuantizeEmbedding(nn.Module):\n    \"\"\"An nn.Embedding module with group quantization\"\"\"\n\n    def __init__(self, num: Union[int, tir.Var], dim: int, config: GroupQuantize):\n        self.num = num\n        self.dim = dim\n        self.config = config\n        num_group = tir.ceildiv(dim, config.group_size)\n        self.q_weight = nn.Parameter(\n            (num, config.num_storage_per_group * num_group), config.storage_dtype\n        )\n        self.q_scale = nn.Parameter((num, num_group), config.model_dtype)\n\n    @staticmethod\n    def from_embedding(embedding: nn.Embedding, config: GroupQuantize) -> \"GroupQuantizeEmbedding\":\n        \"\"\"\n        Converts a non-quantized nn.Embedding to a group quantized GroupQuantizeEmbedding\n\n        Parameters\n        ----------\n        embedding : nn.Embedding\n            The non-quantized nn.Embedding.\n\n        config : GroupQuantize\n            The group quantization config.\n\n        
Returns\n        -------\n        ret : GroupQuantizeEmbedding\n            The group quantized GroupQuantizeEmbedding layer.\n        \"\"\"\n        num, dim = embedding.weight.shape\n        return GroupQuantizeEmbedding(num, dim, config)\n\n    def forward(self, x: nn.Tensor):  # pylint: disable=invalid-name\n        \"\"\"\n        Forward method for group quantized embedding layer.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the embedding layer.\n        \"\"\"\n        w = nn.op.tensor_expr_op(  # pylint: disable=invalid-name\n            lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access\n                weight,\n                scale,\n                axis=-1,\n                out_shape=[\n                    (\n                        tir.IntImm(\"int64\", self.num)\n                        if isinstance(self.num, int)\n                        else weight.shape[0]\n                    ),  # Reuse same tir.Var for symbolic shape (after Exporter)\n                    tir.IntImm(\"int64\", self.dim),\n                ],\n            ),\n            name_hint=\"dequantize\",\n            args=[self.q_weight, self.q_scale],\n        )\n        if x.ndim == 1:\n            return nn.op.take(w, x, axis=0)\n        return nn.op.reshape(\n            nn.op.take(w, nn.op.reshape(x, shape=[-1]), axis=0),\n            shape=[*x.shape, self.dim],\n        )\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which dequantizes the weight\n        and multiplies it with the input tensor.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the lm_head layer.\n        \"\"\"\n        w = nn.op.tensor_expr_op(  # pylint: 
disable=invalid-name\n            lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access\n                weight,\n                scale,\n                axis=-1,\n                out_shape=[\n                    (\n                        tir.IntImm(\"int64\", self.num)\n                        if isinstance(self.num, int)\n                        else weight.shape[0]\n                    ),\n                    tir.IntImm(\"int64\", self.dim),\n                ],\n            ),\n            name_hint=\"dequantize\",\n            args=[self.q_weight, self.q_scale],\n        )\n        w = nn.op.permute_dims(w)\n        return nn.op.matmul(x, w, out_dtype=\"float32\")\n\n\nclass GroupQuantizeMixtralExperts(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"A MixtralExperts module with group quantization\"\"\"\n\n    def __init__(\n        self,\n        num_local_experts,\n        in_features,\n        out_features,\n        config: GroupQuantize,\n    ):  # pylint: disable=too-many-arguments\n        self.num_local_experts = num_local_experts\n        self.in_features = in_features\n        self.out_features = out_features\n        self.config = config\n        num_group = tir.ceildiv(in_features, config.group_size)\n        self.q_weight = nn.Parameter(\n            (num_local_experts, out_features, config.num_storage_per_group * num_group),\n            config.storage_dtype,\n        )\n        self.q_scale = nn.Parameter(\n            (num_local_experts, out_features, num_group), config.model_dtype\n        )\n        self.quantize_dtype = config.quantize_dtype\n        self.group_size = config.group_size\n        self.dtype = config.model_dtype\n        if config.linear_weight_layout == \"KN\":\n            raise NotImplementedError(\"GroupQuantizeMixtralExperts does not support KN layout now.\")\n\n    @staticmethod\n    def from_mixtral_experts(\n        src: \"MixtralExperts\", config: GroupQuantize\n    ) 
-> \"GroupQuantizeMixtralExperts\":\n        \"\"\"\n        Converts a non-quantized MixtralExperts to a group quantized GroupQuantizeMixtralExperts\n\n        Parameters\n        ----------\n        src : MixtralExperts\n            The non-quantized MixtralExperts.\n\n        config : GroupQuantize\n            The group quantization config.\n\n        Returns\n        -------\n        ret : GroupQuantizeMixtralExperts\n            The group quantized GroupQuantizeMixtralExperts layer.\n        \"\"\"\n        quantized_mistral_experts = GroupQuantizeMixtralExperts(\n            num_local_experts=src.num_local_experts,\n            in_features=src.in_features,\n            out_features=src.out_features,\n            config=config,\n        )\n        if \"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, f\"{shard.name}_q_weight\", quantized_mistral_experts.q_weight)\n            apply_sharding(shard, f\"{shard.name}_q_scale\", quantized_mistral_experts.q_scale)\n        return quantized_mistral_experts\n\n    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        \"\"\"Forward method for group quantized Mixtral experts.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        indptr : nn.Tensor\n            The indptr tensor. A 2-D indptr (leading dimension 1) selects the single-batch\n            dequantize-GEMV path; a 1-D indptr selects the dequantize grouped-GEMM path.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the group quantized Mixtral experts layer.\n        \"\"\"\n        from mlc_llm.op import moe_matmul  # pylint: disable=import-outside-toplevel\n\n        assert x.ndim == 2\n        if indptr.ndim == 2:  # single-batch\n            assert indptr.shape[0] == 1\n            return moe_matmul.dequantize_gemv(\n                x,\n                self.q_weight,\n                
self.q_scale,\n                indptr,\n                quantize_dtype=self.quantize_dtype,\n                group_size=self.group_size,\n            )\n        assert indptr.ndim == 1\n        return moe_matmul.dequantize_group_gemm(\n            x,\n            self.q_weight,\n            self.q_scale,\n            indptr,\n            quantize_dtype=self.quantize_dtype,\n            indptr_dtype=indptr.dtype,\n            group_size=self.group_size,\n        )\n"
  },
  {
    "path": "python/mlc_llm/quantization/model_quantization.py",
    "content": "\"\"\"Quantization factory utilities for model quantization.\"\"\"\n\nfrom typing import Any, Callable, Dict, Optional, Tuple, Type\n\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.loader import QuantizeMapping\n\nfrom .awq_quantization import AWQQuantize\nfrom .block_scale_quantization import BlockScaleQuantize\nfrom .ft_quantization import FTQuantize\nfrom .group_quantization import GroupQuantize\nfrom .no_quantization import NoQuantize\nfrom .per_tensor_quantization import PerTensorQuantize\nfrom .quantization import Quantization\n\nFuncQuantization = Callable[[Any, Quantization], Tuple[nn.Module, QuantizeMapping]]\n\n\ndef make_quantization_functions(  # pylint: disable=too-many-arguments, too-many-locals\n    model_cls: Type[nn.Module],\n    *,\n    model_ctor: Optional[Callable[[Any], nn.Module]] = None,\n    supports_group_quant: bool = True,\n    supports_ft_quant: bool = True,\n    supports_awq: bool = False,\n    awq_unsupported_message: Optional[str] = None,\n    supports_per_tensor: bool = False,\n    supports_block_scale: bool = False,\n    set_tensor_parallel_shards: bool = True,\n    per_tensor_use_shards: bool = True,\n) -> Dict[str, FuncQuantization]:\n    \"\"\"Create standard quantization function implementations for a model class.\"\"\"\n\n    def _create_model(model_config: Any) -> nn.Module:\n        if model_ctor is not None:\n            return model_ctor(model_config)\n        return model_cls(model_config)\n\n    def _no_quant(model_config: Any, quantization: NoQuantize) -> Tuple[nn.Module, QuantizeMapping]:\n        model = _create_model(model_config)\n        model.to(quantization.model_dtype)\n        return model, QuantizeMapping({}, {})\n\n    def _group_quant(\n        model_config: Any,\n        quantization: GroupQuantize,\n    ) -> Tuple[nn.Module, QuantizeMapping]:\n        model = _create_model(model_config)\n        model.to(quantization.model_dtype)\n        quant_map = QuantizeMapping({}, {})\n        if 
set_tensor_parallel_shards:\n            if not hasattr(model_config, \"tensor_parallel_shards\"):\n                raise AttributeError(\n                    \"model_config is missing required \"\n                    \"attribute 'tensor_parallel_shards' for group quantization\"\n                )\n            quantization.tensor_parallel_shards = getattr(model_config, \"tensor_parallel_shards\")\n        model = quantization.quantize_model(\n            model,\n            quant_map,\n            \"\",\n        )\n        return model, quant_map\n\n    def _ft_quant(model_config: Any, quantization: FTQuantize) -> Tuple[nn.Module, QuantizeMapping]:\n        model = _create_model(model_config)\n        model.to(quantization.model_dtype)\n        quant_map = QuantizeMapping({}, {})\n        model = quantization.quantize_model(\n            model,\n            quant_map,\n            \"\",\n        )\n        return model, quant_map\n\n    def _awq_quant(\n        model_config: Any, quantization: AWQQuantize\n    ) -> Tuple[nn.Module, QuantizeMapping]:\n        if awq_unsupported_message is not None:\n            raise NotImplementedError(awq_unsupported_message)\n        model = _create_model(model_config)\n        model.to(quantization.model_dtype)\n        quant_map = QuantizeMapping({}, {})\n        model = quantization.quantize_model(\n            model,\n            quant_map,\n            \"\",\n        )\n        return model, quant_map\n\n    def _per_tensor_quant(\n        model_config: Any,\n        quantization: PerTensorQuantize,\n    ) -> Tuple[nn.Module, QuantizeMapping]:\n        model = _create_model(model_config)\n        model.to(quantization.model_dtype)\n        quant_map = QuantizeMapping({}, {})\n        kwargs = {}\n        if per_tensor_use_shards:\n            if not hasattr(model_config, \"tensor_parallel_shards\"):\n                raise AttributeError(\n                    \"model_config is missing required attribute \"\n                   
 \"'tensor_parallel_shards' for per-tensor quantization\"\n                )\n            kwargs[\"tensor_parallel_shards\"] = getattr(model_config, \"tensor_parallel_shards\")\n        model = quantization.quantize_model(\n            model,\n            quant_map,\n            \"\",\n            **kwargs,\n        )\n        return model, quant_map\n\n    def _block_scale_quant(\n        model_config: Any,\n        quantization: BlockScaleQuantize,\n    ) -> Tuple[nn.Module, QuantizeMapping]:\n        model = _create_model(model_config)\n        model.to(quantization.model_dtype)\n        quant_map = QuantizeMapping({}, {})\n        model = quantization.quantize_model(model, quant_map, \"\")\n        return model, quant_map\n\n    quantize_fns: Dict[str, FuncQuantization] = {\"no-quant\": _no_quant}\n    if supports_group_quant:\n        quantize_fns[\"group-quant\"] = _group_quant\n    if supports_ft_quant:\n        quantize_fns[\"ft-quant\"] = _ft_quant\n    if supports_awq:\n        quantize_fns[\"awq\"] = _awq_quant\n    if supports_per_tensor:\n        quantize_fns[\"per-tensor-quant\"] = _per_tensor_quant\n    if supports_block_scale:\n        quantize_fns[\"block-scale-quant\"] = _block_scale_quant\n    return quantize_fns\n\n\ndef make_awq_quant(\n    model_cls: Type[nn.Module],\n) -> Callable[[Any, AWQQuantize], Tuple[nn.Module, QuantizeMapping]]:\n    \"\"\"Create a standard AWQ quantization function for loaders.\"\"\"\n\n    def awq_quant(\n        model_config: Any, quantization: AWQQuantize\n    ) -> Tuple[nn.Module, QuantizeMapping]:\n        model = model_cls(model_config)\n        model.to(quantization.model_dtype)\n        quant_map = QuantizeMapping({}, {})\n        model = quantization.quantize_model(\n            model,\n            quant_map,\n            \"\",\n        )\n        return model, quant_map\n\n    return awq_quant\n"
  },
  {
    "path": "python/mlc_llm/quantization/no_quantization.py",
    "content": "\"\"\"The no quantization config\"\"\"\n\nfrom dataclasses import dataclass\n\n\n@dataclass\nclass NoQuantize:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for no quantization\"\"\"\n\n    name: str\n    kind: str\n    model_dtype: str  # \"float16\", \"float32\"\n\n    def __post_init__(self):\n        assert self.kind == \"no-quant\"\n"
  },
  {
    "path": "python/mlc_llm/quantization/per_tensor_quantization.py",
    "content": "\"\"\"The per-tensor quantization config\"\"\"\n\nimport functools\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union\n\nimport numpy as np\nfrom tvm import DataType, DataTypeCode, IRModule, relax, runtime, te, tir, topi\nfrom tvm.relax.frontend import nn\nfrom tvm.runtime import Tensor\n\nfrom mlc_llm.loader import QuantizeMapping\nfrom mlc_llm.nn import MixtralExperts\nfrom mlc_llm.op import cutlass, extern\nfrom mlc_llm.support import logging\n\nfrom .utils import (\n    apply_sharding,\n    compile_quantize_func,\n    convert_uint_packed_fp8_to_float,\n    is_final_fc,\n    is_moe_gate,\n    pack_weight,\n)\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass PerTensorQuantize:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Configuration for per-tensor quantization\"\"\"\n\n    name: str\n    kind: str\n    activation_dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"]\n    weight_dtype: Literal[\"float8_e4m3fn\", \"float8_e5m2\"]\n    storage_dtype: Literal[\"uint32\", \"float8_e4m3fn\", \"float8_e5m2\"]\n    model_dtype: Literal[\"float16\"]\n    quantize_embedding: bool = True\n    quantize_final_fc: bool = True\n    quantize_linear: bool = True\n\n    num_elem_per_storage: int = 0\n    max_int_value: int = 0\n    use_scale: bool = True\n    # The calibration mode for quantization. If set to \"inference\", the model is built for\n    # inference. 
This should be used after calibration is done.\n    # If set to \"max\", the model is built for calibration that computes the scale using max value of\n    # the activations.\n    calibration_mode: Literal[\"inference\", \"max\"] = \"inference\"\n    tensor_parallel_shards: int = 1\n\n    def __post_init__(self):\n        assert self.kind == \"per-tensor-quant\"\n        self.num_elem_per_storage = (\n            DataType(self.storage_dtype).bits // DataType(self.weight_dtype).bits\n        )\n        self.max_int_value = int(tir.max_value(self.weight_dtype).value)\n        self._quantize_func_cache = {}\n\n    def quantize_model(\n        self,\n        model: nn.Module,\n        quant_map: QuantizeMapping,\n        name_prefix: str,\n        tensor_parallel_shards: int,\n    ) -> nn.Module:\n        \"\"\"\n        Quantize model with per-tensor quantization\n\n        Parameters\n        ----------\n        model : nn.Module\n            The non-quantized nn.Module.\n\n        quant_map : QuantizeMapping\n            The quantize mapping with name mapping and func mapping.\n\n        name_prefix : str\n            The name prefix for visited weight.\n\n        tensor_parallel_shards : int\n            The number of tensor parallel shards.\n\n        Returns\n        -------\n        ret : nn.Module\n            The quantized nn.Module.\n        \"\"\"\n\n        self.tensor_parallel_shards = tensor_parallel_shards\n\n        class _Mutator(nn.Mutator):\n            def __init__(self, config: PerTensorQuantize, quant_map: QuantizeMapping) -> None:\n                super().__init__()\n                self.config = config\n                self.quant_map = quant_map\n\n            def visit_module(self, name: str, node: nn.Module) -> Any:\n                \"\"\"\n                The visiting method for per-tensor quantization of nn.Module nodes.\n\n                Parameters\n                ----------\n                name : str\n                    The name of the 
current node.\n\n                node : nn.Module\n                    The current node of nn.Module to mutate.\n\n                Returns\n                ------\n                ret_node: Any\n                    The new node to replace current node.\n                \"\"\"\n                weight_name = f\"{name}.weight\"\n                param_names = (\n                    [f\"{name}.q_weight\", f\"{name}.q_scale\"]\n                    if self.config.use_scale\n                    else [\n                        f\"{name}.q_weight\",\n                    ]\n                )\n                if (\n                    isinstance(node, nn.Linear)\n                    and self.config.quantize_linear\n                    and (not is_final_fc(name) or self.config.quantize_final_fc)\n                    and not is_moe_gate(name, node)\n                ):\n                    self.quant_map.param_map[weight_name] = param_names\n                    self.quant_map.map_func[weight_name] = self.config.quantize_weight\n                    op = PerTensorQuantizeLinear.from_linear(node, self.config, name)\n                elif isinstance(node, nn.Embedding) and self.config.quantize_embedding:\n                    self.quant_map.param_map[weight_name] = param_names\n                    self.quant_map.map_func[weight_name] = self.config.quantize_weight\n                    op = PerTensorQuantizeEmbedding.from_embedding(node, self.config)\n                elif isinstance(node, MixtralExperts):\n                    self.quant_map.param_map[weight_name] = param_names\n                    self.quant_map.map_func[weight_name] = self.config.quantize_weight\n                    op = PerTensorQuantizeMixtralExperts.from_mixtral_experts(\n                        node, self.config, name\n                    )\n                else:\n                    return self.visit(name, node)\n\n                if hasattr(op, \"q_calibration_scale\") and op.q_calibration_scale:\n                 
   # update quant_map for calibration scale\n                    param_name = f\"{name}.q_calibration_scale\"\n                    old_map_func = self.quant_map.map_func[weight_name]\n\n                    def map_func(*args, **kwargs):\n                        # placeholder for calibration scale, the actual value will be set after\n                        # calibration.\n                        scale = runtime.empty(\n                            shape=op.q_calibration_scale.shape,\n                            dtype=op.q_calibration_scale.dtype,\n                        )\n                        return [*old_map_func(*args, **kwargs), scale]\n\n                    self.quant_map.param_map[weight_name].append(param_name)\n                    self.quant_map.map_func[weight_name] = map_func\n                return op\n\n        model.to(dtype=self.model_dtype)\n        mutator = _Mutator(self, quant_map)\n        model = mutator.visit(name_prefix, model)\n        return model\n\n    def quantize_weight(self, weight) -> List[Tensor]:\n        \"\"\"\n        Quantize weight with per-tensor quantization.\n\n        Parameters\n        ----------\n        weight : Tensor\n            The weight to quantize.\n\n        Returns\n        -------\n        ret : List[Tensor]\n            The quantized weight and the scale if use_scale is True.\n        \"\"\"\n        device = weight.device\n        device_type = device._DEVICE_TYPE_TO_NAME[  # pylint: disable=protected-access\n            device.dlpack_device_type()\n        ]\n\n        def _create_quantize_func() -> IRModule:\n            if DataType(self.weight_dtype).type_code in [\n                DataTypeCode.Float8E4M3FN,\n                DataTypeCode.Float8E5M2,\n            ]:\n                quantize_func = functools.partial(\n                    self.quantize_float8,\n                    quantize_dtype=self.weight_dtype,\n                    storage_dtype=self.storage_dtype,\n                )\n            
else:\n                raise NotImplementedError(f\"Unsupported weight dtype: {self.weight_dtype}\")\n\n            class Quantizer(nn.Module):\n                \"\"\"Quantizer module for per-tensor quantization.\"\"\"\n\n                def main(self, weight: nn.Tensor):  # pylint: disable=missing-function-docstring\n                    return quantize_func(weight)\n\n            mod = Quantizer()\n            mod, _ = mod.export_tvm(  # pylint: disable=unbalanced-tuple-unpacking\n                spec={\"main\": {\"weight\": nn.spec.Tensor(weight.shape, weight.dtype)}}\n            )\n            return mod\n\n        key = f\"({weight.shape}, {weight.dtype}, {device_type})\"\n        quantize_func = self._quantize_func_cache.get(key, None)\n        if quantize_func is None:\n            logger.info(\"Compiling quantize function for key: %s\", key)\n            quantize_func = compile_quantize_func(_create_quantize_func(), device)\n            self._quantize_func_cache[key] = quantize_func\n        return quantize_func(weight)\n\n    def quantize_float8(  # pylint: disable=too-many-locals\n        self,\n        tensor: nn.Tensor,\n        quantize_dtype: str,\n        storage_dtype: str,\n    ) -> Union[Tuple[nn.Tensor], Tuple[nn.Tensor, nn.Tensor]]:\n        \"\"\"Per-tensor quantization for weight tensor, defined in tensor expression.\"\"\"\n\n        if self.use_scale:\n            # min_scaling_factor taken from TRT-LLM\n            def _compute_scale(x: te.Tensor) -> te.Tensor:\n                max_abs = topi.max(topi.abs(x))\n                min_scaling_factor = tir.const(1.0 / (self.max_int_value * 512.0), self.model_dtype)\n                scale = topi.maximum(\n                    max_abs.astype(self.model_dtype) / self.max_int_value,\n                    min_scaling_factor,\n                ).astype(\"float32\")\n                scale = topi.expand_dims(scale, axis=0)\n                return scale\n\n            scale = nn.tensor_expr_op(_compute_scale, \"compute_scale\", args=[tensor])\n        
else:\n            scale = None\n\n        def _compute_quantized_tensor(weight: te.Tensor, scale: Optional[te.Tensor]) -> te.Tensor:\n            elem_storage_dtype = (\n                f\"uint{DataType(quantize_dtype).bits}\"\n                if DataType(storage_dtype).type_code == DataTypeCode.UINT\n                else quantize_dtype\n            )\n            scaled_tensor = te.compute(\n                shape=weight.shape,\n                fcompute=lambda *idx: tir.Cast(\n                    self.storage_dtype,\n                    tir.reinterpret(\n                        elem_storage_dtype,\n                        tir.Cast(\n                            quantize_dtype,\n                            weight(*idx) / scale(0) if scale is not None else weight(*idx),\n                        ),\n                    ),\n                ),\n            )\n\n            if quantize_dtype == self.storage_dtype:\n                return scaled_tensor\n\n            packed_weight = pack_weight(\n                scaled_tensor,\n                axis=-1,\n                num_elem_per_storage=self.num_elem_per_storage,\n                weight_dtype=self.weight_dtype,\n                storage_dtype=self.storage_dtype,\n            )\n\n            return packed_weight\n\n        quantized_tensor = nn.tensor_expr_op(\n            _compute_quantized_tensor, \"compute_quantized_tensor\", args=[tensor, scale]\n        )\n\n        if self.use_scale:\n            return quantized_tensor, scale\n        return (quantized_tensor,)\n\n    def _dequantize(\n        self,\n        q_weight: te.Tensor,\n        scale: Optional[te.Tensor] = None,\n        out_shape: Optional[Sequence[tir.PrimExpr]] = None,\n    ) -> te.Tensor:\n        if self.use_scale:\n            assert scale is not None\n        if DataType(self.weight_dtype).type_code in [\n            DataTypeCode.Float8E4M3FN,\n            DataTypeCode.Float8E5M2,\n        ]:\n            return self.dequantize_float8(q_weight, 
scale, self.weight_dtype, out_shape)\n        raise NotImplementedError()\n\n    def dequantize_float8(\n        self,\n        q_tensor: te.Tensor,\n        scale: Optional[te.Tensor],\n        quantize_dtype: str,\n        out_shape: Optional[Sequence[tir.PrimExpr]] = None,\n    ) -> te.Tensor:\n        \"\"\"Dequantize a fp8 tensor (input or weight) to higher-precision float.\"\"\"\n        if quantize_dtype != self.storage_dtype:\n            dequantized_tensor = convert_uint_packed_fp8_to_float(\n                q_tensor,\n                self.num_elem_per_storage,\n                self.storage_dtype,\n                self.model_dtype,\n                quantize_dtype,\n                axis=-1,\n                out_shape=out_shape,\n            )\n        else:\n            dequantized_tensor = q_tensor.astype(self.model_dtype)\n        if scale is not None:\n            dequantized_tensor = dequantized_tensor * scale.astype(dequantized_tensor.dtype)\n        return dequantized_tensor\n\n\nclass PerTensorQuantizeLinear(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"An nn.Linear module with per-tensor quantization.\"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        in_features: int,\n        out_features: Union[int, tir.Var],\n        config: PerTensorQuantize,\n        name: str,\n        bias: bool = True,\n        out_dtype: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.out_dtype = out_dtype or config.model_dtype\n        self.config = config\n        self.name = name\n        self.q_weight = nn.Parameter(\n            (out_features, tir.ceildiv(in_features, config.num_elem_per_storage)),\n            config.storage_dtype,\n        )\n        self.q_calibration_scale = None\n        if config.use_scale:\n            self.q_scale = nn.Parameter((1,), \"float32\")\n            if 
config.calibration_mode == \"inference\":\n                self.q_calibration_scale = nn.Parameter((1,), \"float32\")\n        else:\n            self.q_scale = None\n        if bias:\n            self.bias = nn.Parameter(\n                (out_features,), config.model_dtype if out_dtype is None else out_dtype\n            )\n        else:\n            self.bias = None\n\n    @classmethod\n    def from_linear(\n        cls, src: nn.Linear, config: PerTensorQuantize, name: str\n    ) -> \"PerTensorQuantizeLinear\":\n        \"\"\"\n        Converts a non-quantized nn.Linear to a per-tensor quantized PerTensorQuantizeLinear\n\n        Parameters\n        ----------\n        src : nn.Linear\n            The non-quantized nn.Linear.\n\n        config : PerTensorQuantize\n            The per-tensor quantization config.\n\n        name: str\n            The name of the layer.\n\n        Returns\n        -------\n        ret : PerTensorQuantizeLinear\n            The per-tensor quantized PerTensorQuantizeLinear layer.\n        \"\"\"\n        out_features, in_features = src.weight.shape\n        quantized_linear = cls(\n            in_features=in_features,\n            out_features=out_features,\n            config=config,\n            name=name,\n            bias=getattr(src, \"bias\", None) is not None,\n            out_dtype=src.out_dtype,\n        )\n        if quantized_linear.bias is not None:\n            quantized_linear.bias.attrs = src.bias.attrs\n        if \"shard_strategy\" in src.weight.attrs:\n            shard = src.weight.attrs[\"shard_strategy\"]\n            apply_sharding(shard, f\"{shard.name}_q_weight\", quantized_linear.q_weight)\n            # scale doesn't need to be sharded since it's the same for all shards\n        return quantized_linear\n\n    def forward(self, x: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        \"\"\"\n        Forward method for per-tensor quantized linear layer.\n\n        Parameters\n        ----------\n   
     x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the per-tensor quantized linear layer.\n        \"\"\"\n        # Note: Use calibration scale when calibration is enabled\n        if self.config.calibration_mode == \"inference\":\n            if self.q_calibration_scale:\n                x /= self.q_calibration_scale.astype(x.dtype)\n            x_q = x.astype(self.config.activation_dtype)\n            x_scale = self.q_calibration_scale\n        elif self.config.calibration_mode == \"max\":\n            _, x_scale = self.config.quantize_float8(  # type: ignore\n                x,\n                quantize_dtype=self.config.activation_dtype,\n                storage_dtype=self.config.storage_dtype,\n            )\n            if self.config.tensor_parallel_shards > 1:\n                x_scale = nn.ccl_allreduce(x_scale, \"max\")\n            x_scale = nn.extern(\n                \"mlc_llm.calibration_observer\",\n                [f\"{self.name}.q_calibration_scale\", \"max\", x_scale],\n                out=nn.Tensor.placeholder(x_scale.shape, x_scale.dtype),\n            )\n            x_q = (x / x_scale.astype(x.dtype)).astype(self.config.activation_dtype)\n            x = x_q.astype(self.config.model_dtype) * x_scale.astype(self.config.model_dtype)\n        else:\n            raise ValueError(f\"Unknown calibration mode: {self.config.calibration_mode}\")\n\n        if (\n            self.config.weight_dtype == self.config.storage_dtype\n            and self.config.calibration_mode == \"inference\"\n        ):\n            if (\n                extern.get_store().cutlass_gemm\n                and functools.reduce(lambda x, y: x * y, x_q.shape[:-1]) != 1\n            ):\n                # Dispatch to cutlass kernel for gemm when cutlass is available.\n                scale = (\n                    x_scale * self.q_scale\n                    if self.config.use_scale\n     
               else nn.wrap_nested(\n                        relax.Constant(runtime.tensor(np.array([1.0]).astype(\"float32\"))),\n                        \"scale\",\n                    )\n                )\n                return cutlass.fp8_gemm(\n                    x_q,\n                    self.q_weight,\n                    scale,\n                    self.config.weight_dtype,\n                    self.config.model_dtype,\n                )\n            x = nn.op.matmul(x_q, nn.permute_dims(self.q_weight), out_dtype=\"float32\")\n            if self.config.use_scale:\n                scale = x_scale * self.q_scale\n                x = x * scale\n            x = x.astype(self.out_dtype)\n        else:\n            w = nn.op.tensor_expr_op(\n                lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access\n                    weight,\n                    scale,\n                    out_shape=[\n                        (\n                            tir.IntImm(\"int64\", self.out_features)\n                            if isinstance(self.out_features, int)\n                            else weight.shape[0]\n                        ),\n                        tir.IntImm(\"int64\", self.in_features),\n                    ],\n                ),\n                \"dequantize\",\n                args=[self.q_weight, self.q_scale],\n            )\n            x = nn.op.matmul(x, nn.permute_dims(w), out_dtype=self.out_dtype)\n        if self.bias is not None:\n            x = x + self.bias\n        return x\n\n    def to(self, dtype: Optional[str] = None) -> None:\n        \"\"\"\n        Override to() such that we do not convert bias if there is an out_dtype.\n        Otherwise, we might run into dtype mismatch when computing x + self.bias.\n        \"\"\"\n        self.q_weight.to(dtype=dtype)\n        if self.q_scale:\n            self.q_scale.to(dtype=dtype)\n        if self.bias is not None and self.out_dtype is None:\n            
self.bias.to(dtype=dtype)\n        if dtype is not None and isinstance(getattr(self, \"dtype\", None), str):\n            self.dtype = dtype  # pylint: disable=attribute-defined-outside-init\n\n\nclass PerTensorQuantizeEmbedding(nn.Module):\n    \"\"\"An nn.Embedding module with per-tensor quantization\"\"\"\n\n    def __init__(self, num: Union[int, tir.Var], dim: int, config: PerTensorQuantize):\n        self.num = num\n        self.dim = dim\n        self.config = config\n        self.q_weight = nn.Parameter(\n            (num, tir.ceildiv(dim, config.num_elem_per_storage)), config.storage_dtype\n        )\n        if self.config.use_scale:\n            self.q_scale = nn.Parameter((1,), \"float32\")\n        else:\n            self.q_scale = None\n\n    @staticmethod\n    def from_embedding(\n        embedding: nn.Embedding, config: PerTensorQuantize\n    ) -> \"PerTensorQuantizeEmbedding\":\n        \"\"\"\n        Converts a non-quantized nn.Embedding to a per-tensor quantized PerTensorQuantizeEmbedding\n\n        Parameters\n        ----------\n        embedding : nn.Embedding\n            The non-quantized nn.Embedding.\n\n        config : PerTensorQuantize\n            The per-tensor quantization config.\n\n        Returns\n        -------\n        ret : PerTensorQuantizeEmbedding\n            The per-tensor quantized embedding layer.\n        \"\"\"\n        num, dim = embedding.weight.shape\n        return PerTensorQuantizeEmbedding(num, dim, config)\n\n    def forward(self, x: nn.Tensor):  # pylint: disable=invalid-name\n        \"\"\"\n        Forward method for per-tensor quantized embedding layer.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the embedding layer.\n        \"\"\"\n        w = nn.op.tensor_expr_op(\n            lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access\n         
       weight,\n                scale,\n                out_shape=[\n                    tir.IntImm(\"int64\", self.num) if isinstance(self.num, int) else weight.shape[0],\n                    tir.IntImm(\"int64\", self.dim),\n                ],\n            ),\n            \"dequantize\",\n            args=[self.q_weight, self.q_scale],\n        )\n        if x.ndim == 1:\n            return nn.op.take(w, x, axis=0)\n        return nn.op.reshape(\n            nn.op.take(w, nn.op.reshape(x, shape=[-1]), axis=0),\n            shape=[*x.shape, self.dim],\n        )\n\n    def lm_head_forward(self, x: nn.Tensor):\n        \"\"\"The lm_head forwarding, which dequantizes the weight\n        and multiplies it with the input tensor.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the lm_head layer.\n        \"\"\"\n        w = nn.op.tensor_expr_op(\n            lambda weight, scale: self.config._dequantize(  # pylint: disable=protected-access\n                weight,\n                scale,\n                out_shape=[\n                    tir.IntImm(\"int64\", self.num) if isinstance(self.num, int) else weight.shape[0],\n                    tir.IntImm(\"int64\", self.dim),\n                ],\n            ),\n            \"dequantize\",\n            args=[self.q_weight, self.q_scale],\n        )\n        w = nn.op.permute_dims(w)\n        return nn.op.matmul(x, w, out_dtype=\"float32\")\n\n\nclass PerTensorQuantizeMixtralExperts(nn.Module):  # pylint: disable=too-many-instance-attributes\n    \"\"\"A MixtralExperts module with per-tensor quantization\"\"\"\n\n    _IMPL: Dict[str, Type[\"PerTensorQuantizeMixtralExperts\"]] = {}\n\n    def __init__(\n        self,\n        num_local_experts,\n        in_features,\n        out_features,\n        config: PerTensorQuantize,\n        name: str,\n    ):  # pylint: 
disable=too-many-arguments\n        self.num_local_experts = num_local_experts\n        self.in_features = in_features\n        self.out_features = out_features\n        self.config = config\n        self.name = name\n        self.q_weight = nn.Parameter(\n            (\n                num_local_experts,\n                out_features,\n                tir.ceildiv(in_features, config.num_elem_per_storage),\n            ),\n            config.storage_dtype,\n        )\n        self.q_calibration_scale = None\n        if config.use_scale:\n            self.q_scale = nn.Parameter((1,), \"float32\")\n            if config.calibration_mode == \"inference\":\n                self.q_calibration_scale = nn.Parameter((1,), \"float32\")\n        else:\n            self.q_scale = None\n\n    @staticmethod\n    def from_mixtral_experts(\n        src: \"MixtralExperts\",\n        config: PerTensorQuantize,\n        name: str,\n    ) -> \"PerTensorQuantizeMixtralExperts\":\n        \"\"\"\n        Converts a non-quantized MixtralExperts to a per-tensor quantized\n        PerTensorQuantizeMixtralExperts\n\n        Parameters\n        ----------\n        src : MixtralExperts\n            The non-quantized MixtralExperts\n\n        config : PerTensorQuantize\n            The per-tensor quantization config\n\n        name: str\n            The name of the layer.\n\n        Returns\n        -------\n        ret : PerTensorQuantizeMixtralExperts\n            The per-tensor quantized MixtralExperts layer\n        \"\"\"\n        if DataType(config.weight_dtype).type_code in [\n            DataTypeCode.Float8E4M3FN,\n            DataTypeCode.Float8E5M2,\n        ]:\n            return PerTensorQuantizeMixtralExperts._IMPL[\"fp8\"].from_mixtral_experts(\n                src, config, name\n            )\n        raise NotImplementedError()\n\n    def forward(self, x: nn.Tensor, indptr: nn.Tensor) -> nn.Tensor:  # pylint: disable=invalid-name\n        \"\"\"Forward method for per-tensor 
quantized Mixtral experts.\n\n        Parameters\n        ----------\n        x : nn.Tensor\n            The input tensor.\n\n        indptr : nn.Tensor\n            The indptr tensor.\n\n        Returns\n        -------\n        ret : nn.Tensor\n            The output tensor for the per-tensor quantized Mixtral experts layer.\n        \"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "python/mlc_llm/quantization/quantization.py",
    "content": "\"\"\"A centralized registry of all existing quantization methods and their configurations.\"\"\"\n\nfrom typing import Any, Dict\n\nfrom .awq_quantization import AWQQuantize\nfrom .block_scale_quantization import BlockScaleQuantize\nfrom .ft_quantization import FTQuantize\nfrom .group_quantization import GroupQuantize\nfrom .no_quantization import NoQuantize\nfrom .per_tensor_quantization import PerTensorQuantize\n\nQuantization = Any\n\"\"\"Quantization is an object that represents a quantization algorithm. It is required to\nhave the following fields:\n\n    name : str\n        The name of the quantization algorithm, for example, \"q4f16_1\".\n\n    kind : str\n        The kind of quantization algorithm, for example, \"group-quant\", \"ft-quant\".\n\nIt is also required to have the following methods:\n\n    def quantize_model(self, module: nn.Module) -> nn.Module:\n        ...\n\n    def quantize_weight(self, weight: tvm.runtime.Tensor) -> List[tvm.runtime.Tensor]:\n        ...\n\"\"\"\n\nQUANTIZATION: Dict[str, Quantization] = {\n    \"q0f16\": NoQuantize(\n        name=\"q0f16\",\n        kind=\"no-quant\",\n        model_dtype=\"float16\",\n    ),\n    \"q0bf16\": NoQuantize(\n        name=\"q0bf16\",\n        kind=\"no-quant\",\n        model_dtype=\"bfloat16\",\n    ),\n    \"q0f32\": NoQuantize(\n        name=\"q0f32\",\n        kind=\"no-quant\",\n        model_dtype=\"float32\",\n    ),\n    \"q3f16_0\": GroupQuantize(\n        name=\"q3f16_0\",\n        kind=\"group-quant\",\n        group_size=40,\n        quantize_dtype=\"int3\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float16\",\n        linear_weight_layout=\"KN\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q3f16_1\": GroupQuantize(\n        name=\"q3f16_1\",\n        kind=\"group-quant\",\n        group_size=40,\n        quantize_dtype=\"int3\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float16\",\n   
     linear_weight_layout=\"NK\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q4f16_0\": GroupQuantize(\n        name=\"q4f16_0\",\n        kind=\"group-quant\",\n        group_size=32,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float16\",\n        linear_weight_layout=\"KN\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q4f16_1\": GroupQuantize(\n        name=\"q4f16_1\",\n        kind=\"group-quant\",\n        group_size=32,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float16\",\n        linear_weight_layout=\"NK\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q4bf16_0\": GroupQuantize(\n        name=\"q4bf16_0\",\n        kind=\"group-quant\",\n        group_size=32,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"bfloat16\",\n        linear_weight_layout=\"KN\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q4bf16_1\": GroupQuantize(\n        name=\"q4bf16_1\",\n        kind=\"group-quant\",\n        group_size=32,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"bfloat16\",\n        linear_weight_layout=\"NK\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q4f32_1\": GroupQuantize(\n        name=\"q4f32_1\",\n        kind=\"group-quant\",\n        group_size=32,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float32\",\n        linear_weight_layout=\"NK\",\n        quantize_embedding=True,\n        quantize_final_fc=True,\n    ),\n    \"q4f16_2\": GroupQuantize(\n        name=\"q4f16_2\",\n        kind=\"group-quant\",\n        group_size=32,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float16\",\n        
linear_weight_layout=\"NK\",\n        quantize_embedding=False,\n        quantize_final_fc=False,\n    ),\n    \"q4f16_autoawq\": AWQQuantize(\n        name=\"q4f16_autoawq\",\n        kind=\"awq\",\n        group_size=128,\n        quantize_dtype=\"int4\",\n        storage_dtype=\"uint32\",\n        model_dtype=\"float16\",\n    ),\n    \"q4f16_ft\": FTQuantize(\n        name=\"q4f16_ft\",\n        kind=\"ft-quant\",\n        quantize_dtype=\"int4\",\n        storage_dtype=\"int8\",\n        model_dtype=\"float16\",\n    ),\n    \"e5m2_e5m2_f16\": PerTensorQuantize(\n        name=\"e5m2_e5m2_f16\",\n        kind=\"per-tensor-quant\",\n        activation_dtype=\"float8_e5m2\",\n        weight_dtype=\"float8_e5m2\",\n        storage_dtype=\"float8_e5m2\",\n        model_dtype=\"float16\",\n        quantize_final_fc=False,\n        quantize_embedding=False,\n        quantize_linear=True,\n        use_scale=False,\n    ),\n    \"e4m3_e4m3_f16\": PerTensorQuantize(\n        name=\"e4m3_e4m3_f16\",\n        kind=\"per-tensor-quant\",\n        activation_dtype=\"float8_e4m3fn\",\n        weight_dtype=\"float8_e4m3fn\",\n        storage_dtype=\"float8_e4m3fn\",\n        model_dtype=\"float16\",\n        quantize_final_fc=False,\n        quantize_embedding=False,\n        quantize_linear=True,\n        use_scale=True,\n        calibration_mode=\"inference\",\n    ),\n    \"e4m3_e4m3_f16_max_calibrate\": PerTensorQuantize(\n        name=\"e4m3_e4m3_f16_max_calibrate\",\n        kind=\"per-tensor-quant\",\n        activation_dtype=\"float8_e4m3fn\",\n        weight_dtype=\"float8_e4m3fn\",\n        storage_dtype=\"float8_e4m3fn\",\n        model_dtype=\"float16\",\n        quantize_final_fc=False,\n        quantize_embedding=False,\n        quantize_linear=True,\n        use_scale=True,\n        calibration_mode=\"max\",\n    ),\n    \"fp8_e4m3fn_bf16_block_scale\": BlockScaleQuantize(\n        name=\"fp8_e4m3fn_bf16_block_scale\",\n        kind=\"block-scale-quant\",\n      
  weight_dtype=\"float8_e4m3fn\",\n        model_dtype=\"bfloat16\",\n    ),\n    \"fp8_e4m3fn_bf16_block_scale_static_activation\": BlockScaleQuantize(\n        name=\"fp8_e4m3fn_bf16_block_scale_static_activation\",\n        kind=\"block-scale-quant\",\n        weight_dtype=\"float8_e4m3fn\",\n        model_dtype=\"bfloat16\",\n        use_activation_scale=True,\n    ),\n}\n"
  },
  {
    "path": "python/mlc_llm/quantization/utils.py",
    "content": "\"\"\"Common utilities for quantization\"\"\"\n\nfrom typing import Callable, List, Optional, Sequence\n\nfrom tvm import IRModule, relax, te, tir\nfrom tvm.relax.frontend import nn\nfrom tvm.runtime import DataType, DataTypeCode\nfrom tvm.s_tir import dlight as dl\nfrom tvm.target import Target\n\nfrom mlc_llm.support import tensor_parallel as tp\n\n\ndef convert_uint_to_float(  # pylint: disable=too-many-arguments\n    weight: te.Tensor,\n    bits: int,\n    num_elem_per_storage: int,\n    storage_dtype: str,\n    model_dtype: str,\n    axis: int = -1,\n    out_shape: Optional[List[tir.PrimExpr]] = None,\n    ft_reorder: Optional[bool] = False,\n) -> te.Tensor:\n    \"\"\"Convert a quantized uint weight to an unquantized float weight.\"\"\"\n    tir_bin_mask = tir.const((1 << bits) - 1, storage_dtype)\n    if out_shape is None:\n        # Copy the shape into a mutable list; te.Tensor.shape does not support item assignment.\n        out_shape = list(weight.shape)\n        out_shape[axis] *= num_elem_per_storage\n    axis = axis if axis >= 0 else len(out_shape) + axis\n    return te.compute(\n        shape=out_shape,\n        fcompute=lambda *idx: tir.bitwise_and(\n            tir.shift_right(\n                weight(*idx[:axis], idx[axis] // num_elem_per_storage, *idx[axis + 1 :]),\n                (\n                    (\n                        (idx[axis] % num_elem_per_storage) % 2 * 4\n                        + (idx[axis] % num_elem_per_storage) // 2\n                    )\n                    * bits\n                    if ft_reorder\n                    else (idx[axis] % num_elem_per_storage) * bits\n                ).astype(storage_dtype),\n            ),\n            tir_bin_mask,\n        ).astype(model_dtype),\n    )\n\n\ndef is_final_fc(name: str) -> bool:\n    \"\"\"Determines whether the parameter is the last layer based on its name.\"\"\"\n    # TODO: use a more specific condition to determine final fc  # pylint: disable=fixme\n    return name in [\"head\", \"lm_head\", \"lm_head.linear\", \"embed_out\"]\n\n\ndef is_moe_gate(name: str, 
node: nn.Linear) -> bool:\n    \"\"\"Check whether the parameter is the MoE gate layer.\"\"\"\n    return name.endswith(\"gate\") and isinstance(node.out_features, int) and node.out_features <= 256\n\n\ndef compile_quantize_func(mod: IRModule, device) -> Callable:\n    \"\"\"Compile a quantization function for a given device.\"\"\"\n    device_type = device._DEVICE_TYPE_TO_NAME[  # pylint: disable=protected-access\n        device.dlpack_device_type()\n    ]\n    if device_type in [\"cuda\", \"rocm\", \"metal\", \"vulkan\", \"opencl\"]:\n        target = Target.current()\n        if target is None:\n            target = Target.from_device(device)\n        with target:\n            mod = dl.ApplyDefaultSchedule(  # type: ignore   # pylint: disable=not-callable\n                dl.gpu.Reduction(),\n                dl.gpu.GeneralReduction(),\n                dl.gpu.Fallback(),\n            )(mod)\n    elif device_type == \"cpu\":\n        target = \"llvm\"\n        mod = relax.transform.LegalizeOps()(mod)\n    else:\n        raise NotImplementedError(f\"Device type {device_type} is not supported\")\n    ex = relax.build(mod, target=target)\n    vm = relax.VirtualMachine(ex, device)  # pylint: disable=invalid-name\n    return vm[\"main\"]\n\n\ndef apply_sharding(shard_strategy, name: str, weight: nn.Parameter):\n    \"\"\"Apply sharding strategy to a weight.\"\"\"\n    if isinstance(shard_strategy, tp.ShardSingleDim):\n        weight.attrs[\"shard_strategy\"] = tp.ShardSingleDim(\n            name=name,\n            dim=shard_strategy.dim,\n            segs=shard_strategy.segs,\n        )\n    else:\n        raise NotImplementedError(f\"Unknown sharding strategy: {shard_strategy}\")\n\n\ndef convert_uint_packed_fp8_to_float(  # pylint: disable=too-many-arguments\n    weight: te.Tensor,\n    num_elem_per_storage: int,\n    storage_dtype: str,\n    model_dtype: str,\n    quant_dtype: str,\n    axis: int = -1,\n    out_shape: Optional[Sequence[tir.PrimExpr]] = None,\n) 
-> te.Tensor:\n    \"\"\"Unpack a fp8 value from the storage dtype and convert to float.\"\"\"\n    assert quant_dtype in [\"float8_e4m3fn\", \"float8_e5m2\"]\n    assert DataType(storage_dtype).type_code == DataTypeCode.UINT\n    bits = DataType(quant_dtype).bits\n    elem_storage_dtype = DataType(f\"uint{bits}\")\n    tir_bin_mask = tir.const((1 << bits) - 1, \"uint8\")\n    if axis < 0:\n        axis += len(weight.shape)\n    if out_shape is None:\n        out_shape = (\n            *weight.shape[:axis],\n            weight.shape[axis] * num_elem_per_storage,\n            *weight.shape[axis + 1 :],\n        )\n    axis = axis if axis >= 0 else len(out_shape) + axis\n    return te.compute(\n        shape=out_shape,\n        fcompute=lambda *idx: tir.reinterpret(\n            quant_dtype,\n            tir.bitwise_and(\n                tir.shift_right(\n                    weight(*idx[:axis], idx[axis] // num_elem_per_storage, *idx[axis + 1 :]),\n                    ((idx[axis] % num_elem_per_storage) * bits).astype(storage_dtype),\n                ).astype(elem_storage_dtype),\n                tir_bin_mask,\n            ),\n        ).astype(model_dtype),\n    )\n\n\ndef pack_weight(\n    weight: te.Tensor,\n    axis: int,\n    num_elem_per_storage: int,\n    weight_dtype: str,\n    storage_dtype: str,\n    out_shape: Optional[Sequence[tir.PrimExpr]] = None,\n):  # pylint: disable=too-many-arguments\n    \"\"\"Convert a tensor to a packed format by packing consecutive bits.\n    This can be useful for sub-byte quantization.\n\n    Parameters\n    ----------\n    weight : te.Tensor\n        The weight\n    axis : int\n        The axis to pack.\n    num_elem_per_storage : int\n        The number of elements per storage.\n    weight_dtype : str\n        The dtype of the input tensor.\n    storage_dtype : str\n        The dtype of the packed tensor.\n    out_shape : Optional[Sequence[tir.PrimExpr]]\n        The output shape of the packed tensor. 
Zero-padding is added if needed.\n    \"\"\"\n    assert weight.dtype == storage_dtype\n    shape = weight.shape\n    if axis < 0:\n        axis += len(shape)\n    k = shape[axis]\n    axis = axis if axis >= 0 else len(shape) + axis\n    if out_shape is None:\n        out_shape = (\n            *shape[:axis],\n            tir.ceildiv(k, num_elem_per_storage),\n            *shape[axis + 1 :],\n        )\n    r = te.reduce_axis((0, num_elem_per_storage), name=\"r\")  # pylint: disable=invalid-name\n    packed_weight = te.compute(\n        shape=out_shape,\n        fcompute=lambda *idx: tir.sum(\n            tir.if_then_else(\n                idx[axis] * num_elem_per_storage + r < k,\n                weight(*idx[:axis], idx[axis] * num_elem_per_storage + r, *idx[axis + 1 :])\n                << (r * DataType(weight_dtype).bits),\n                tir.const(0, storage_dtype),\n            ),\n            axis=r,\n        ),\n        name=\"packed_weight\",\n    ).astype(storage_dtype)\n    return packed_weight\n"
  },
  {
    "path": "python/mlc_llm/router/__init__.py",
    "content": "\"\"\"Subdirectory of router, which routes to multiple engine endpoints.\"\"\"\n\nfrom .. import base\nfrom .router import Router\n"
  },
  {
    "path": "python/mlc_llm/router/router.py",
    "content": "\"\"\"Programmable router for dispatching OpenAI API to Microserving API\"\"\"\n\nimport json\nimport math\nimport threading\nfrom typing import Any, AsyncGenerator, Iterable, List, Literal, Optional, Tuple\n\nimport aiohttp  # pylint: disable=import-error\nimport tvm\n\nfrom mlc_llm.protocol import openai_api_protocol\nfrom mlc_llm.serve import EngineConfig, PopenServer\nfrom mlc_llm.serve.entrypoints import microserving_entrypoints\nfrom mlc_llm.tokenizers import Tokenizer\n\n\nclass Router:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Programmable Router Implementation\"\"\"\n\n    def __init__(\n        self,\n        model: str,\n        model_lib: Optional[str] = None,\n        hosts: Optional[List[str]] = None,\n        ports: Optional[List[int]] = None,\n        num_gpus: Optional[List[int]] = None,\n        enable_prefix_cache: bool = False,\n        router_mode: Literal[\"disagg\", \"round-robin\"] = \"disagg\",\n        pd_balance_factor: float = 0.0,\n    ):  # pylint: disable=too-many-arguments,too-many-locals\n        \"\"\"\n        Spawn len(hosts) server endpoints with Popen.\n        \"\"\"\n        if hosts is None:\n            hosts = [\"127.0.0.1\"]\n        if ports is None:\n            ports = [8080]\n        if num_gpus is None:\n            num_gpus = [1]\n\n        self.router_mode = router_mode\n        self.pd_balance_factor = pd_balance_factor\n        # Get endpoint urls\n        self.num_servers = len(hosts)\n        assert self.num_servers == len(ports) == len(num_gpus)\n        self.hosts = hosts\n        self.ports = ports\n        self.server_urls = []\n        for i in range(self.num_servers):\n            self.server_urls.append(f\"http://{hosts[i]}:{ports[i]}\")\n\n        # Misc\n        self.headers = {\"Content-Type\": \"application/json\"}\n        self.num_running_requests = [0] * self.num_servers\n\n        # Call nvshmem_init here to get uid, then pass to env variables to 
server.start() below\n        f_init_nvshmem_uid = tvm.get_global_func(\"runtime.disco.nvshmem.init_nvshmem_uid\")\n        uid = list(f_init_nvshmem_uid())\n\n        # Start underlying servers concurrently. Otherwise 1 server cannot start on its own\n        # since initializing the nvshmem world requires all GPUs.\n        self.servers: List[PopenServer] = []\n\n        self.device_id_starts = [0]\n        for num_gpus_val in num_gpus:\n            self.device_id_starts.append(self.device_id_starts[-1] + num_gpus_val)\n        # device_id_starts[-1] is the total number of GPUs.\n\n        def start_server(i: int):\n            nvshmem_config = {\n                \"uid\": uid,\n                \"npes\": self.device_id_starts[-1],  # total number of workers in the nvshmem world\n                \"pe_start\": self.device_id_starts[i],  # start of PE for this endpoint's workers\n            }\n\n            server = PopenServer(\n                model=model,\n                model_lib=model_lib,\n                host=hosts[i],\n                port=ports[i],\n                enable_debug=True,\n                device=f\"cuda:{self.device_id_starts[i]}\",\n                mode=\"server\",\n                engine_config=EngineConfig(\n                    prefix_cache_mode=\"radix\" if enable_prefix_cache else \"disable\",\n                    gpu_memory_utilization=0.8,\n                ),\n            )\n            self.servers.append(server)\n            server.start(extra_env={\"MLC_NVSHMEM_INIT_CONFIG_JSON_STR\": json.dumps(nvshmem_config)})\n\n        threads = []\n        num_used_gpus = 0\n        for i in range(self.num_servers):\n            thread = threading.Thread(\n                target=start_server,\n                args=[i],\n            )\n            num_used_gpus += num_gpus[i]\n            thread.start()\n            threads.append(thread)\n        for thread in threads:\n            thread.join()\n        self.tokenizer = Tokenizer(model)\n\n    def 
terminate(self):\n        \"\"\"Terminate the underlying servers\"\"\"\n        for server in self.servers:\n            server.terminate()\n\n    async def handle_completion(\n        self,\n        request: openai_api_protocol.CompletionRequest,\n        request_id: str,\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"\n        Handle a completion request from API with a schedule.\n        \"\"\"\n        if isinstance(request.prompt, str):\n            request.prompt = self.tokenizer.encode(request.prompt)\n        # Add a debugConfig if not present\n        if request.debug_config is None:\n            request.debug_config = openai_api_protocol.DebugConfig()\n        completed = False\n        while not completed:\n            completed = True\n            async for response in self.translate_request(request, request_id):\n                if response is None:\n                    completed = False\n                    break\n                yield response\n\n    async def translate_request(\n        self, request: openai_api_protocol.CompletionRequest, request_id: str\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"\n        Translate OpenAI API request to microserving API calls.\n        \"\"\"\n        if self.router_mode == \"disagg\":\n            async for response in self._handle_completion_disagg(\n                request, request_id, pd_balance_factor=self.pd_balance_factor\n            ):\n                yield response\n        elif self.router_mode == \"round-robin\":\n            async for response in self._handle_completion_round_robin(request):\n                yield response\n        else:\n            raise ValueError(\"Cannot reach here\")\n\n    def _pick_endpoint(self, endpoint_ids: Iterable[int]) -> int:\n        # Pick the least congested endpoint.\n        endpoint_id = -1\n        min_running_req = int(1e9)\n        for candidate_id in endpoint_ids:\n            
if self.num_running_requests[candidate_id] < min_running_req:\n                min_running_req = self.num_running_requests[candidate_id]\n                endpoint_id = candidate_id\n        assert endpoint_id != -1\n        return endpoint_id\n\n    async def _handle_completion_round_robin(\n        self,\n        request: openai_api_protocol.CompletionRequest,\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"\n        Handle a completion request from API. Given a streaming request, yields multiple response\n        chunks. Given a non-streaming request, yields a single response. Dispatch the request to\n        the endpoint with the fewest running requests, at request level.\n        \"\"\"\n        # Pick the least-loaded endpoint\n        cur_endpoint = self._pick_endpoint(range(self.num_servers))\n        self.num_running_requests[cur_endpoint] += 1\n        payload = request.model_dump()\n        async with aiohttp.ClientSession(\n            timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True\n        ) as session:\n            # pylint: disable=fixme\n            # todo: replace this with start_generate\n            # pylint: enable=fixme\n            async with session.post(\n                self.server_urls[cur_endpoint] + \"/v1/completions\",\n                json=payload,\n                headers=self.headers,\n            ) as response:\n                assert response.status == 200, await response.text()\n                if payload[\"stream\"]:\n                    async for chunk in response.content:\n                        # Convert raw bytes to CompletionResponse\n                        chunk = chunk.strip()\n                        if not chunk or chunk == b\"\\n\":\n                            continue\n                        # Get rid of the prefix \"data: \" and suffix \"\\n\"\n                        raw_data = chunk[6:].strip()\n                        if raw_data == b\"[DONE]\":\n                            continue\n                        data = json.loads(raw_data)\n                        # Commented because we still want usage chunk to be passed back\n                        # if not data[\"choices\"]:\n                        #     continue\n                        response = openai_api_protocol.CompletionResponse.model_validate(data)\n                        if response.choices:\n                            reason = response.choices[0].finish_reason\n                            if reason == \"preempt\":\n                                yield None\n                        yield response\n                else:\n                    data = await response.json()\n                    response = openai_api_protocol.CompletionResponse.model_validate(data)\n                    if response.choices:\n                        reason = response.choices[0].finish_reason\n                        if reason == \"preempt\":\n                            yield None\n                    yield response\n            self.num_running_requests[cur_endpoint] -= 1\n\n    #\n    # Below methods are for disaggregated serving\n    # Note that only _handle_completion_disagg() has scheduling logic. The other three\n    # helper methods only implement the individual steps of the flow.\n    #\n    async def _handle_completion_disagg(  # pylint: disable=too-many-locals\n        self,\n        original_request: openai_api_protocol.CompletionRequest,\n        request_id: str,\n        pd_balance_factor=0,\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"\n        Handle a completion request from API with disaggregated scheduling. Given two servers\n        P (prefill) and D (decode), the router does the following:\n            1. Ask D to prepare metadata, receive D's metadata\n            (prefix cache, KV append positions, etc.)\n            2. Send P the prefill request and D's metadata, receive ack\n            3. Ask D to start decoding, receive the response as a normal stream\n        \"\"\"\n        original_request.user = request_id\n        # Arbitrarily designate server 0 as P and the other servers as D\n        prefill_server_id = 0\n        decode_server_id = self._pick_endpoint(range(1, self.num_servers))\n\n        # Tell D to prepare metadata for prompt[0:kv_window_end].\n        # P does not need to sample. Ask D to treat the last\n        # token like the first sampled token.\n        kv_window_end = (\n            -1\n            if math.fabs(pd_balance_factor) < 1e-5\n            else int((1 - pd_balance_factor) * len(original_request.prompt))\n        )\n        async with aiohttp.ClientSession(\n            timeout=aiohttp.ClientTimeout(total=3 * 3600), trust_env=True\n        ) as session:\n            self.num_running_requests[decode_server_id] += 1\n            try:\n                # 1. Ask D to prepare metadata\n                prep_recv_request = microserving_entrypoints.PrepRecvRequest(\n                    **original_request.model_dump(), end=kv_window_end\n                )\n                (\n                    kv_append_metadata_base64,\n                    prefix_matched_length,\n                ) = await self.send_prepare_receive(\n                    session=session,\n                    request=prep_recv_request,\n                    server_url=self.server_urls[decode_server_id],\n                )\n\n                kv_window_end = (\n                    len(original_request.prompt) + kv_window_end\n                    if kv_window_end < 0\n                    else kv_window_end\n                )\n                assert prefix_matched_length <= kv_window_end\n\n                # 2. Send P the prefill request and D's metadata. When it returns, P has finished\n                # prefilling and transferring the KV of\n                # prompt[prefix_matched_length:kv_window_end]. 
So D is ready to decode.\n                if prefix_matched_length < kv_window_end:\n                    remote_send_request = microserving_entrypoints.RemoteSendRequest(\n                        **original_request.model_dump(),\n                        begin=prefix_matched_length,\n                        end=kv_window_end,\n                        kv_addr_info=kv_append_metadata_base64,\n                        recv_rank=self.device_id_starts[decode_server_id],\n                    )\n                    await self.send_remote_send(\n                        session=session,\n                        request=remote_send_request,\n                        server_url=self.server_urls[prefill_server_id],\n                    )\n\n                # 3. Start decoding, receive and yield back response as a normal request\n                # The kv window passed through denotes the range to prefill on the\n                # decode server, which should be [-1:] here.\n                start_generate_request = microserving_entrypoints.StartGenerateRequest(\n                    **original_request.model_dump(),\n                    begin=kv_window_end,\n                )\n                async for response in self.send_start_generate(\n                    session=session,\n                    request=start_generate_request,\n                    server_url=self.server_urls[decode_server_id],\n                ):\n                    if len(response.choices) > 0:\n                        finish_reason = response.choices[0].finish_reason\n                        if finish_reason == \"preempt\":\n                            yield None\n                    yield response\n            except Exception as e:\n                self.num_running_requests[decode_server_id] -= 1\n                raise e\n            self.num_running_requests[decode_server_id] -= 1\n\n    async def send_prepare_receive(\n        self,\n        session: aiohttp.ClientSession,\n        request: 
openai_api_protocol.CompletionRequest,\n        server_url: str,\n    ) -> Tuple[str, int]:\n        \"\"\"\n        Performs step 1 of disaggregated serving: ask D to prepare metadata.\n        Returns:\n            The metadata received from D, which is a tuple of 2 elements:\n                - kv_append_metadata_base64: str, info about KV append encoded in base64 string\n                - prefix_matched_length: int, length of the matched prefix.\n                    i.e. prompt[0:prefix_matched_length] is the matched prefix\n        \"\"\"\n        # Send request to the decode server for receive preparation.\n        # Get the prompt length, matched prefix length and the KV metadata.\n        async with session.post(\n            server_url + \"/microserving/prep_recv\",\n            json=request.model_dump(),\n            headers=self.headers,\n        ) as response:\n            assert response.status == 200, await response.text()\n            data = await response.json()\n\n            return (\n                data[\"kv_append_metadata\"],\n                data[\"prefix_matched_length\"],\n            )\n\n    async def send_remote_send(\n        self,\n        session: aiohttp.ClientSession,\n        request: openai_api_protocol.CompletionRequest,\n        server_url: str,\n    ) -> None:\n        \"\"\"\n        Performs step 2 of disaggregated serving: ask P to prefill and transfer KV to D.\n        P returns an empty chunk to acknowledge completion.\n        \"\"\"\n        # Send request to P and get ack\n        async with session.post(\n            server_url + \"/microserving/remote_send\",\n            json=request.model_dump(),\n            headers=self.headers,\n        ) as response:\n            assert response.status == 200, await response.text()\n            await response.json()\n\n    async def send_start_generate(\n        self,\n        session: aiohttp.ClientSession,\n        request: openai_api_protocol.CompletionRequest,\n        
server_url: str,\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"\n        Performs step 3 of disaggregated serving: ask D to decode and return normal response.\n        \"\"\"\n        # pylint: disable=fixme\n        # Todo: return string directly to reduce str->json->str roundtrip overhead\n        # pylint: enable=fixme\n        async with session.post(\n            server_url + \"/microserving/start_generate\",\n            json=request.model_dump(),\n            headers=self.headers,\n        ) as response:\n            assert response.status == 200, await response.text()\n            if request.stream:\n                async for chunk in response.content:\n                    # Convert raw bytes to CompletionResponse\n                    chunk = chunk.strip()\n                    if not chunk or chunk == b\"\\n\":\n                        continue\n                    # Get rid of the prefix \"data: \" and suffix \"\\n\"\n                    raw_data = chunk[6:].strip()\n                    if raw_data == b\"[DONE]\":\n                        continue\n                    data = json.loads(raw_data)\n                    # Commented because we still want usage chunk to be passed back\n                    # if not data[\"choices\"]:\n                    #     continue\n                    yield openai_api_protocol.CompletionResponse.model_validate(data)\n            else:\n                data = await response.json()\n                yield openai_api_protocol.CompletionResponse.model_validate(data)\n"
  },
  {
    "path": "python/mlc_llm/serve/__init__.py",
    "content": "\"\"\"Subdirectory of serving.\"\"\"\n\n# Load MLC LLM library by importing base\nfrom .. import base\nfrom .config import EngineConfig\nfrom .data import Data, ImageData, RequestStreamOutput, TextData, TokenData\nfrom .embedding_engine import AsyncEmbeddingEngine\nfrom .engine import AsyncMLCEngine, MLCEngine\nfrom .radix_tree import PagedRadixTree\nfrom .request import Request\nfrom .server import PopenServer\n"
  },
  {
    "path": "python/mlc_llm/serve/_ffi_api.py",
    "content": "\"\"\"FFI APIs for mlc_llm.serve\"\"\"\n\nimport tvm_ffi\n\n# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the \"mlc.serve\" prefix.\n# e.g. TVM_FFI_REGISTER_GLOBAL(\"mlc.serve.TextData\")\ntvm_ffi.init_ffi_api(\"mlc.serve\", __name__)  # pylint: disable=protected-access\n"
  },
  {
    "path": "python/mlc_llm/serve/config.py",
    "content": "\"\"\"Configuration dataclasses used in MLC LLM serving\"\"\"\n\nimport json\nfrom dataclasses import asdict, dataclass, field\nfrom typing import List, Literal, Optional, Tuple, Union\n\n\n@dataclass\nclass EngineConfig:  # pylint: disable=too-many-instance-attributes\n    \"\"\"The class of MLCEngine execution configuration.\n\n    Parameters\n    ----------\n    model : str\n        The path to the model directory.\n\n    model_lib : str\n        The path to the model library.\n\n    additional_models : List[Union[str, Tuple[str, str]]]\n        The paths to the additional models' directories (and model libraries).\n        Each element is a single string (denoting the model directory)\n        or a tuple of two strings (denoting the model directory and model lib path).\n\n    mode : Literal[\"local\", \"interactive\", \"server\"]\n        The engine mode in MLC LLM.\n        We provide three preset modes: \"local\", \"interactive\" and \"server\".\n        The default mode is \"local\".\n        The choice of mode decides the values of \"max_num_sequence\", \"max_total_sequence_length\"\n        and \"prefill_chunk_size\" when they are not explicitly specified.\n        1. Mode \"local\" refers to the local server deployment which has low\n        request concurrency. So the max batch size will be set to 4, and max\n        total sequence length and prefill chunk size are set to the context\n        window size (or sliding window size) of the model.\n        2. Mode \"interactive\" refers to the interactive use of server, which\n        has at most 1 concurrent request. So the max batch size will be set to 1,\n        and max total sequence length and prefill chunk size are set to the context\n        window size (or sliding window size) of the model.\n        3. 
Mode \"server\" refers to the large server use case which may handle\n        many concurrent requests and wants to use GPU memory as much as possible.\n        In this mode, we will automatically infer the largest possible max batch\n        size and max total sequence length.\n\n        You can manually specify arguments \"max_num_sequence\", \"max_total_sequence_length\" and\n        \"prefill_chunk_size\" to override the automatically inferred values.\n\n    tensor_parallel_shards : Optional[int]\n        Number of shards to split the model into in tensor parallelism multi-GPU inference.\n        When \"model_lib\" is given, this field will be ignored, and the tensor_parallel_shards\n        in the model_lib metadata will be used.\n\n    pipeline_parallel_stages : Optional[int]\n        Number of pipeline stages to split the model layers for pipeline parallelism.\n        When \"model_lib\" is given, this field will be ignored, and the pipeline_parallel_stages\n        in the model_lib metadata will be used.\n\n    opt : Optional[str]\n        The optimization flags for JIT compilation.\n        When \"model_lib\" is given, this field will be ignored.\n        MLC LLM maintains a predefined set of optimization flags,\n        denoted as O0, O1, O2, O3, where O0 means no optimization, O2 enables the majority of them,\n        and O3 represents extreme optimization that could potentially break the system.\n        Meanwhile, optimization flags can be explicitly specified via detailed knobs, e.g.\n        \"cublas_gemm=1;cudagraph=0\".\n\n    gpu_memory_utilization : Optional[float]\n        A number in (0, 1) denoting the fraction of GPU memory used by the server in total.\n        It is used to infer the maximum possible KV cache capacity.\n        When it is unspecified, it defaults to 0.85.\n        Under mode \"local\" or \"interactive\", the actual memory usage may be\n        significantly smaller than this number. Under mode \"server\", the actual\n        memory usage may be slightly larger than this number.\n\n    kv_cache_page_size : int\n        The number of consecutive tokens in each page of the paged KV cache.\n\n    max_num_sequence : Optional[int]\n        The maximum number of sequences that are allowed to be\n        processed by the KV cache at any time.\n\n    max_total_sequence_length : Optional[int]\n        The maximum total number of tokens whose KV data are allowed\n        to exist in the KV cache at any time.\n\n    max_single_sequence_length : Optional[int]\n        The maximum length allowed for a single sequence in the engine.\n\n    prefill_chunk_size : Optional[int]\n        The maximum total sequence length in a prefill.\n\n    sliding_window_size : Optional[int]\n        The sliding window size in sliding window attention (SWA).\n\n    attention_sink_size : Optional[int]\n        The number of attention sinks when sliding window is enabled.\n\n    max_history_size: Optional[int]\n        The maximum history size for RNN state to roll back.\n\n    kv_state_kind: Optional[Literal[\"kv_cache\", \"rnn_state\"]]\n        The kind of KV state (KV cache or RNN state).\n\n    speculative_mode : Literal[\"disable\", \"small_draft\", \"eagle\", \"medusa\"]\n        The speculative mode.\n        \"disable\" means speculative decoding is disabled.\n        \"small_draft\" means the normal speculative decoding (small draft) mode.\n        \"eagle\" means the eagle-style speculative decoding.\n        \"medusa\" means the medusa-style speculative decoding.\n\n    spec_draft_length : int\n        The number of tokens to generate in each speculative proposal (draft).\n        A value of 0 enables the adaptive speculative mode, where the draft length\n        is automatically adjusted based on the engine state.\n\n    spec_tree_width : int\n        The width of the speculative decoding tree.\n\n    prefix_cache_mode : Literal[\"disable\", \"radix\"]\n        The prefix cache mode.\n        \"disable\" means the prefix cache is disabled.\n        \"radix\" means the paged radix tree based prefix cache mode.\n\n    prefix_cache_max_num_recycling_seqs: Optional[int]\n        The maximum number of recycling sequences in the prefix cache.\n        Defaults to max_num_sequence. Set 0 to disable the prefix cache,\n        or -1 for a prefix cache with infinite capacity.\n\n    prefill_mode : Literal[\"chunked\", \"hybrid\"]\n        The prefill mode.\n        \"chunked\" means the basic prefill with chunked input enabled.\n        \"hybrid\" means the hybrid prefill or split-fuse,\n        so that decode steps are converted into prefill.\n\n    verbose : bool\n        A boolean indicating whether to print logging info in the engine.\n    \"\"\"\n\n    model: Optional[str] = None\n    model_lib: Optional[str] = None\n    additional_models: List[Union[str, Tuple[str, str]]] = field(default_factory=list)\n    mode: Optional[Literal[\"local\", \"interactive\", \"server\"]] = None\n    tensor_parallel_shards: Optional[int] = None\n    pipeline_parallel_stages: Optional[int] = None\n    opt: Optional[str] = None\n    gpu_memory_utilization: Optional[float] = None\n    kv_cache_page_size: int = 16\n    max_num_sequence: Optional[int] = None\n    max_total_sequence_length: Optional[int] = None\n    max_single_sequence_length: Optional[int] = None\n    prefill_chunk_size: Optional[int] = None\n    sliding_window_size: Optional[int] = None\n    attention_sink_size: Optional[int] = None\n    max_history_size: Optional[int] = None\n    kv_state_kind: Optional[Literal[\"kv_cache\", \"rnn_state\"]] = None\n    speculative_mode: Literal[\"disable\", \"small_draft\", \"eagle\", \"medusa\"] = \"disable\"\n    spec_draft_length: int = 0\n    spec_tree_width: int = 1\n    prefix_cache_mode: Literal[\"disable\", \"radix\"] = \"radix\"\n    prefix_cache_max_num_recycling_seqs: Optional[int] = None\n    prefill_mode: Literal[\"chunked\", \"hybrid\"] = \"hybrid\"\n    verbose: bool = True\n\n    def asjson(self) -> str:\n        \"\"\"Return the config as a JSON string.\"\"\"\n        return json.dumps(asdict(self))\n\n    @staticmethod\n    def from_json(json_str: str) -> \"EngineConfig\":\n        \"\"\"Construct a config from a JSON string.\"\"\"\n        return EngineConfig(**json.loads(json_str))\n"
  },
  {
    "path": "python/mlc_llm/serve/data.py",
    "content": "\"\"\"Classes denoting multi-modality data used in MLC LLM serving\"\"\"\n\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Tuple\n\nimport tvm\nimport tvm_ffi\nfrom tvm.runtime import Object, Tensor\n\nfrom . import _ffi_api\n\n\n@tvm_ffi.register_object(\"mlc.serve.Data\")  # pylint: disable=protected-access\nclass Data(Object):  # pylint: disable=too-few-public-methods\n    \"\"\"The base class of multi-modality data (text, tokens, embedding, etc).\"\"\"\n\n    def __init__(self):  # pylint: disable=super-init-not-called\n        pass\n\n\n@tvm_ffi.register_object(\"mlc.serve.TextData\")  # pylint: disable=protected-access\nclass TextData(Data):\n    \"\"\"The class of text data, containing a text string.\n\n    Parameters\n    ----------\n    text : str\n        The text string.\n    \"\"\"\n\n    def __init__(self, text: str):\n        self.__init_handle_by_constructor__(_ffi_api.TextData, text)  # type: ignore  # pylint: disable=no-member\n\n    @property\n    def text(self) -> str:\n        \"\"\"The text data in `str`.\"\"\"\n        return str(_ffi_api.TextDataGetTextString(self))  # type: ignore  # pylint: disable=no-member\n\n    def __str__(self) -> str:\n        return self.text\n\n\n@tvm_ffi.register_object(\"mlc.serve.TokenData\")  # type: ignore  # pylint: disable=protected-access\nclass TokenData(Data):  # pylint: disable=too-few-public-methods\n    \"\"\"The class of token data, containing a list of token ids.\n\n    Parameters\n    ----------\n    token_ids : List[int]\n        The list of token ids.\n    \"\"\"\n\n    def __init__(self, token_ids: List[int]):\n        self.__init_handle_by_constructor__(_ffi_api.TokenData, *token_ids)  # type: ignore  # pylint: disable=no-member\n\n    @property\n    def token_ids(self) -> List[int]:\n        \"\"\"Return the token ids of the TokenData.\"\"\"\n        return list(_ffi_api.TokenDataGetTokenIds(self))  # type: ignore  # pylint: disable=no-member\n\n\n# 
mypy: disable-error-code=\"attr-defined\"\n@tvm_ffi.register_object(\"mlc.serve.ImageData\")  # type: ignore  # pylint: disable=protected-access\nclass ImageData(Data):\n    \"\"\"The class of image data, containing the image as Tensor.\n\n    Parameters\n    ----------\n    image : tvm.runtime.Tensor\n        The image data.\n    \"\"\"\n\n    def __init__(self, image: Tensor, embed_size: int):\n        self.embed_size = embed_size\n        self.__init_handle_by_constructor__(_ffi_api.ImageData, image, embed_size)  # type: ignore  # pylint: disable=no-member\n\n    @property\n    def image(self) -> Tensor:\n        \"\"\"Return the image data.\"\"\"\n        return _ffi_api.ImageDataGetImage(self)  # type: ignore  # pylint: disable=no-member\n\n    def __len__(self):\n        return self.embed_size\n\n    # pylint: disable=too-many-locals,unused-argument,unused-argument\n    @staticmethod\n    def from_url(url: str, config: Dict) -> \"ImageData\":\n        \"\"\"Get the image from the given URL, process and return the image tensor as TVM Tensor.\"\"\"\n\n        # pylint: disable=import-outside-toplevel, import-error\n        import base64\n        from io import BytesIO\n\n        import numpy as np\n        import requests\n        from PIL import Image\n\n        if url.startswith(\"data:image\"):\n            # The image is encoded in base64 format\n            base64_image = url.split(\",\")[1]\n            image_data = base64.b64decode(base64_image)\n            image_tensor = Image.open(BytesIO(image_data)).convert(\"RGB\")\n        elif url.startswith(\"http\"):\n            response = requests.get(url, timeout=5)\n            image_tensor = Image.open(BytesIO(response.content)).convert(\"RGB\")\n        else:\n            raise ValueError(f\"Unsupported image URL format: {url}\")\n\n        # image_embed_size = ImageData.get_embed_size(config)\n        # TODO: fix these hard-coded values for phi3.5-vision and llava # pylint: disable=fixme\n        
image_embed_size = 576\n        if config[\"model_type\"] == \"phi3_v\":\n            image_embed_size = 1921\n        image_tensor = np.expand_dims(image_tensor, axis=0)  # HWC -> NHWC\n        image_features = tvm.runtime.tensor(image_tensor)\n        image_data = ImageData(image_features, image_embed_size)\n        return image_data\n\n    @staticmethod\n    def get_embed_size(config: Dict) -> int:\n        \"\"\"Get the image embedding size from the model config file.\"\"\"\n        image_size = config[\"model_config\"][\"vision_config\"][\"image_size\"]\n        patch_size = config[\"model_config\"][\"vision_config\"][\"patch_size\"]\n        embed_size = (image_size // patch_size) ** 2\n        return embed_size\n\n    @staticmethod\n    def get_input_size(config: Dict) -> int:\n        \"\"\"Get the image input size from the model config file.\"\"\"\n        image_size = config[\"model_config\"][\"vision_config\"][\"image_size\"]\n        return image_size\n\n\n@dataclass\nclass SingleRequestStreamOutput:\n    \"\"\"The request stream output of a single request.\n\n    Attributes\n    ----------\n    delta_token_ids : List[int]\n        The newly generated tokens since the last callback invocation\n        for the input request.\n\n    delta_logprob_json_strs : Optional[List[str]]\n        The logprobs JSON strings of the newly generated tokens\n        since the last invocation.\n\n    finish_reason : Optional[str]\n        The finish reason of the request when it is finished,\n        or None if the request has not finished yet.\n    \"\"\"\n\n    delta_token_ids: List[int]\n    delta_logprob_json_strs: Optional[List[str]]\n    finish_reason: Optional[str]\n    request_final_usage_json_str: Optional[str]\n    extra_prefix_string: str\n\n\n@tvm_ffi.register_object(\"mlc.serve.RequestStreamOutput\")  # pylint: disable=protected-access\nclass RequestStreamOutput(Object):  # pylint: disable=too-few-public-methods\n    \"\"\"The generated delta request output that is streamed back\n    through the callback stream function.\n    It contains two fields (in order):\n\n    request_id : str\n        The id of the request that the function is invoked for.\n\n    stream_outputs : List[SingleRequestStreamOutput]\n        The output instances, one per request.\n\n    Note\n    ----\n    We do not provide a constructor, since in practice only the C++ side\n    instantiates this class.\n    \"\"\"\n\n    def unpack(self) -> Tuple[str, List[SingleRequestStreamOutput]]:\n        \"\"\"Return the fields of the delta output in a tuple.\n\n        Returns\n        -------\n        request_id : str\n            The id of the request that the function is invoked for.\n\n        stream_outputs : List[SingleRequestStreamOutput]\n            The output instances, one per request.\n        \"\"\"\n        fields = _ffi_api.RequestStreamOutputUnpack(self)  # type: ignore  # pylint: disable=no-member\n        request_final_usage_json_str = fields[4]\n        request_id = str(fields[0])\n        if request_final_usage_json_str is not None:\n            return (\n                request_id,\n                [SingleRequestStreamOutput([], None, None, request_final_usage_json_str, \"\")],\n            )\n\n        stream_outputs = []\n        for i, (delta_token_ids, finish_reason, extra_prefix_string) in enumerate(\n            zip(fields[1], fields[3], fields[5])\n        ):\n            delta_logprob_json_strs = (\n                [str(logprob_json_str) for logprob_json_str in fields[2][i]]\n                if fields[2] is not None\n                else None\n            )\n            stream_outputs.append(\n                SingleRequestStreamOutput(\n                    delta_token_ids=list(delta_token_ids),\n                    delta_logprob_json_strs=delta_logprob_json_strs,\n                    finish_reason=str(finish_reason) if finish_reason is not None else None,\n                    request_final_usage_json_str=None,\n                    extra_prefix_string=str(extra_prefix_string),\n                )\n            )\n        return request_id, stream_outputs\n"
  },
  {
    "path": "python/mlc_llm/serve/embedding_engine.py",
    "content": "\"\"\"Asynchronous embedding inference engine for encoder and decoder models.\"\"\"\n\nimport asyncio\nimport concurrent.futures\nimport json\nimport os\nfrom typing import List, Literal, Optional, Tuple, Union\n\nimport numpy as np\nimport tvm\nfrom tvm import relax\nfrom tvm.runtime import Device, ShapeTuple\n\nfrom mlc_llm.serve import engine_utils\nfrom mlc_llm.support.auto_device import detect_device\nfrom mlc_llm.tokenizers import Tokenizer\n\n\nclass AsyncEmbeddingEngine:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Asynchronous embedding inference engine.\n\n    Supports both encoder models (BERT-style) and decoder-only embedding models\n    (e.g. Qwen3-Embeddings). Uses a ThreadPoolExecutor for background inference\n    so that the asyncio event loop is not blocked.\n\n    Parameters\n    ----------\n    model : str\n        Path to the model weight directory.\n\n    model_lib : str\n        Path to the compiled model library (.so/.dylib file).\n\n    device : Union[str, Device]\n        Device string, e.g. \"auto\", \"cuda:0\", \"metal\".\n\n    pooling_strategy : Optional[str]\n        Pooling strategy: \"cls\" (first token), \"mean\" (masked average),\n        or \"last\" (last token). 
If None, auto-detected based on model type:\n        encoder -> \"cls\", decoder -> \"last\".\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-branches\n        self,\n        model: str,\n        model_lib: str,\n        device: Union[str, Device] = \"auto\",\n        *,\n        pooling_strategy: Optional[str] = None,\n    ) -> None:\n        # Reuse existing utility: device detection\n        self.device = detect_device(device) if isinstance(device, str) else device\n        # Reuse existing utility: tokenizer\n        self.tokenizer = Tokenizer(model)\n\n        # Load TVM module, metadata, and params via engine_utils helpers\n        ex = tvm.runtime.load_module(model_lib)\n        vm = relax.VirtualMachine(ex, device=self.device)\n        self._mod = vm.module\n        self._metadata = json.loads(self._mod[\"_metadata\"]())\n        self._params = engine_utils.load_embedding_params(model, self.device, self._metadata)\n\n        # Detect model type and set pooling strategy\n        self.embedding_metadata = engine_utils.get_embedding_metadata(self._metadata)\n        if self.embedding_metadata:\n            self.model_type = self.embedding_metadata[\"model_type\"]\n            self.pooling_strategy = self.embedding_metadata[\"pooling_strategy\"]\n            self.normalize = self.embedding_metadata[\"normalize\"]\n        else:\n            self.model_type = engine_utils.detect_embedding_model_type(self._mod)\n            self.pooling_strategy = \"cls\" if self.model_type == \"encoder\" else \"last\"\n            self.normalize = True\n        # Allow caller to override pooling strategy\n        if pooling_strategy:\n            self.pooling_strategy = pooling_strategy\n\n        # Initialize model-type-specific functions\n        if self.model_type == \"encoder\":\n            self._init_encoder(model)\n        else:\n            self._init_decoder(model)\n\n        # Background thread pool (1 worker = serialized GPU inference)\n        
self._executor = concurrent.futures.ThreadPoolExecutor(\n            max_workers=1, thread_name_prefix=\"embedding\"\n        )\n        self._terminated = False\n\n    def _init_encoder(self, model: str) -> None:\n        \"\"\"Initialize encoder (BERT-style) model functions and special tokens.\"\"\"\n        self._prefill_func = self._mod[\"prefill\"]\n        self._cls_token_id: Optional[int] = None\n        self._sep_token_id: Optional[int] = None\n        tok_config_path = os.path.join(model, \"tokenizer_config.json\")\n        if os.path.exists(tok_config_path):\n            with open(tok_config_path, encoding=\"utf-8\") as f:\n                tok_config = json.load(f)\n            # Try added_tokens_decoder first (newer HF format)\n            added = tok_config.get(\"added_tokens_decoder\", {})\n            for tid, info in added.items():\n                if info.get(\"content\") == tok_config.get(\"cls_token\"):\n                    self._cls_token_id = int(tid)\n                if info.get(\"content\") == tok_config.get(\"sep_token\"):\n                    self._sep_token_id = int(tid)\n            # Fallback: encode the special token strings via tokenizer\n            if self._cls_token_id is None and tok_config.get(\"cls_token\"):\n                ids = list(self.tokenizer.encode(tok_config[\"cls_token\"]))\n                if len(ids) == 1:\n                    self._cls_token_id = ids[0]\n            if self._sep_token_id is None and tok_config.get(\"sep_token\"):\n                ids = list(self.tokenizer.encode(tok_config[\"sep_token\"]))\n                if len(ids) == 1:\n                    self._sep_token_id = ids[0]\n\n    def _init_decoder(self, model: str) -> None:\n        \"\"\"Initialize decoder (Qwen3-Embeddings style) model functions.\"\"\"\n        # Prefer tokenizer post-processing (HF-style) for terminal/pooling token handling.\n        # Only fall back to manual EOS append when tokenizer does not define a post-processor\n        # 
that actually appends a token at the end of the sequence.\n        self._decoder_tokenizer_appends_eos = False\n        tokenizer_json_path = os.path.join(model, \"tokenizer.json\")\n        if os.path.exists(tokenizer_json_path):\n            with open(tokenizer_json_path, encoding=\"utf-8\") as f:\n                tokenizer_json = json.load(f)\n            post_proc = tokenizer_json.get(\"post_processor\")\n            if post_proc is not None:\n                # Check if the post-processor actually appends a special token at the end\n                # (e.g. TemplateProcessing with \"$A <|endoftext|>\"). We verify by encoding\n                # a test string and checking if the last token is a known special token.\n                test_tokens = list(self.tokenizer.encode(\"test\"))\n                if len(test_tokens) > 0:\n                    vocab = tokenizer_json.get(\"added_tokens\", [])\n                    special_ids = {t[\"id\"] for t in vocab if t.get(\"special\", False)}\n                    if test_tokens[-1] in special_ids:\n                        self._decoder_tokenizer_appends_eos = True\n\n        # Read EOS token from config — fallback only when tokenizer does not auto-append.\n        self._decoder_eos_token_id: Optional[int] = None\n        config_path = os.path.join(model, \"mlc-chat-config.json\")\n        if os.path.exists(config_path):\n            with open(config_path, encoding=\"utf-8\") as f:\n                chat_config = json.load(f)\n            eos = chat_config.get(\"eos_token_id\")\n            if isinstance(eos, list):\n                self._decoder_eos_token_id = eos[0]\n            elif isinstance(eos, int):\n                self._decoder_eos_token_id = eos\n\n        self._embed_func = self._mod[\"embed\"]\n        self._prefill_to_hidden_func = self._mod[\"prefill_to_last_hidden_states\"]\n        self._batch_prefill_to_hidden_func = self._mod[\"batch_prefill_to_last_hidden_states\"]\n        if 
self._mod.implements_function(\"create_tir_paged_kv_cache\"):\n            self._create_kv_cache_func = self._mod[\"create_tir_paged_kv_cache\"]\n        elif self._mod.implements_function(\"create_flashinfer_paged_kv_cache\"):\n            self._create_kv_cache_func = self._mod[\"create_flashinfer_paged_kv_cache\"]\n        else:\n            raise RuntimeError(\"Cannot find KV cache creation function in model library.\")\n        self._kv_state_add_sequence = tvm.get_global_func(\"vm.builtin.kv_state_add_sequence\")\n        self._kv_state_remove_sequence = tvm.get_global_func(\"vm.builtin.kv_state_remove_sequence\")\n        self._kv_state_begin_forward = tvm.get_global_func(\"vm.builtin.kv_state_begin_forward\")\n        self._kv_state_end_forward = tvm.get_global_func(\"vm.builtin.kv_state_end_forward\")\n        self._nd_reshape = tvm.get_global_func(\"vm.builtin.reshape\")\n\n    def embed(self, inputs: List[str]) -> Tuple[List[List[float]], int]:\n        \"\"\"Compute embeddings for a list of input strings (synchronous).\n\n        Parameters\n        ----------\n        inputs : List[str]\n            The input strings to embed.\n\n        Returns\n        -------\n        embeddings : List[List[float]]\n            The embedding vectors, L2-normalized when ``self.normalize`` is True.\n        total_tokens : int\n            Total number of tokens processed.\n        \"\"\"\n        if self.model_type == \"encoder\":\n            return self._embed_encoder(inputs)\n        return self._embed_decoder(inputs)\n\n    async def async_embed(self, inputs: List[str]) -> Tuple[List[List[float]], int]:\n        \"\"\"Compute embeddings asynchronously in a background thread.\n\n        This method does not block the asyncio event loop.\n\n        Parameters\n        ----------\n        inputs : List[str]\n            The input strings to embed.\n\n        Returns\n        -------\n        embeddings : List[List[float]]\n            The embedding vectors, L2-normalized when ``self.normalize`` is True.\n        total_tokens : int\n 
           Total number of tokens processed.\n        \"\"\"\n        loop = asyncio.get_running_loop()\n        return await loop.run_in_executor(self._executor, self.embed, inputs)\n\n    def _embed_encoder(  # pylint: disable=too-many-locals\n        self, inputs: List[str]\n    ) -> Tuple[List[List[float]], int]:\n        \"\"\"Encoder model embedding (BERT-style).\n\n        Processes each input individually to avoid batch padding artifacts.\n\n        Encoder uses bidirectional attention, so chunked prefill is NOT possible\n        (each token must attend to all other tokens in the full sequence).\n        Inputs exceeding prefill_chunk_size are truncated.\n\n        (Additional Strategy)\n        TODO: For better long-text support, implement sliding window + mean pooling:\n          1. Split text into overlapping windows of prefill_chunk_size (stride=chunk/2)\n          2. Encode each window independently\n          3. Mean-pool all window embeddings → final embedding → L2 normalize\n          This preserves information from the full text at the cost of N× compute.\n        \"\"\"\n        embeddings: List[List[float]] = []\n        total_tokens = 0\n        prefill_chunk = self._metadata.get(\"prefill_chunk_size\", 512)\n\n        for text in inputs:\n            tokens = list(self.tokenizer.encode(text))\n            # Add [CLS] and [SEP] if needed\n            if self._cls_token_id is not None and (\n                len(tokens) == 0 or tokens[0] != self._cls_token_id\n            ):\n                tokens = [self._cls_token_id] + tokens\n            if self._sep_token_id is not None and (\n                len(tokens) == 0 or tokens[-1] != self._sep_token_id\n            ):\n                tokens = tokens + [self._sep_token_id]\n\n            # Truncate to compiled buffer limit (keep [CLS] at start, [SEP] at end)\n            if len(tokens) > prefill_chunk:\n                tokens = tokens[:prefill_chunk]\n                if self._sep_token_id is not 
None:\n                    tokens[-1] = self._sep_token_id\n\n            seq_len = len(tokens)\n            total_tokens += seq_len\n\n            token_ids = np.array([tokens], dtype=np.int32)  # [1, seq_len]\n            attention_mask: np.ndarray = np.ones((1, seq_len), dtype=np.int32)  # [1, seq_len]\n\n            tokens_tvm = tvm.runtime.tensor(token_ids, device=self.device)\n            mask_tvm = tvm.runtime.tensor(attention_mask, device=self.device)\n\n            output = self._prefill_func(tokens_tvm, mask_tvm, self._params)\n            # .numpy() copies to CPU, escaping TVM workspace buffer reuse across calls.\n            output_np = output.numpy()  # [1, seq_len, hidden_size]\n\n            # Pooling\n            if self.pooling_strategy == \"cls\":\n                pooled = output_np[0, 0, :]\n            elif self.pooling_strategy == \"mean\":\n                pooled = output_np[0].mean(axis=0)\n            else:  # \"last\"\n                pooled = output_np[0, -1, :]\n\n            # L2 normalize\n            pooled = pooled.astype(np.float32)\n            if self.normalize:\n                norm = np.linalg.norm(pooled)\n                if norm > 1e-12:\n                    pooled = pooled / norm\n\n            embeddings.append(pooled.tolist())\n\n        return embeddings, total_tokens\n\n    def _embed_decoder(self, inputs: List[str]) -> Tuple[List[List[float]], int]:\n        \"\"\"Decoder model embedding with batch prefill optimization.\n\n        When total tokens fit within prefill_chunk_size, all inputs are processed\n        in a single batch forward pass using shared KV cache. 
Otherwise, falls back\n        to sequential chunked prefill per input.\n        \"\"\"\n        # Read KV cache config from metadata\n        prefill_chunk = self._metadata.get(\"prefill_chunk_size\", 2048)\n        max_seq_len = self._metadata.get(\"context_window_size\", 32768)\n        if max_seq_len == -1:\n            max_seq_len = self._metadata.get(\"sliding_window_size\", -1)\n        assert max_seq_len > 0, f\"max_seq_len must be positive, got {max_seq_len}\"\n        support_sliding = int(self._metadata.get(\"sliding_window_size\", -1) != -1)\n\n        # Tokenize all inputs. Prefer tokenizer post-processor output. If absent (older models),\n        # fall back to appending eos_token_id when missing.\n        token_lists: List[List[int]] = []\n        for text in inputs:\n            tokens = list(self.tokenizer.encode(text))\n            if (\n                not self._decoder_tokenizer_appends_eos\n                and self._decoder_eos_token_id is not None\n                and (len(tokens) == 0 or tokens[-1] != self._decoder_eos_token_id)\n            ):\n                tokens.append(self._decoder_eos_token_id)\n            if len(tokens) > max_seq_len:\n                tokens = tokens[:max_seq_len]\n            token_lists.append(tokens)\n\n        total_tokens = sum(len(t) for t in token_lists)\n\n        # Fast path: all tokens fit in one prefill chunk → batch forward\n        if total_tokens <= prefill_chunk and all(len(t) > 0 for t in token_lists):\n            return self._batch_embed_decoder(\n                token_lists, total_tokens, max_seq_len, prefill_chunk, support_sliding\n            )\n\n        # Greedy sub-batching: pack texts into sub-batches that fit within\n        # prefill_chunk, preserving input order. 
Oversize texts (single text\n        # exceeding prefill_chunk) fall back to sequential chunked prefill.\n        sub_batches = self._build_sub_batches(token_lists, prefill_chunk)\n        all_embeddings: List[List[float]] = []\n        for batch_type, batch, batch_total in sub_batches:\n            if batch_type == \"batch\":\n                embs, _ = self._batch_embed_decoder(\n                    batch, batch_total, max_seq_len, prefill_chunk, support_sliding\n                )\n            else:\n                embs, _ = self._sequential_embed_decoder(\n                    batch, batch_total, max_seq_len, prefill_chunk, support_sliding\n                )\n            all_embeddings.extend(embs)\n\n        return all_embeddings, total_tokens\n\n    @staticmethod\n    def _build_sub_batches(\n        token_lists: List[List[int]], prefill_chunk: int\n    ) -> List[Tuple[Literal[\"batch\", \"sequential\"], List[List[int]], int]]:\n        \"\"\"Partition token lists into sub-batches that fit within prefill_chunk.\n\n        Each sub-batch is a tuple of (mode, token_lists, total_token_count).\n        Empty token lists are skipped to avoid invalid batch processing.\n        \"\"\"\n        sub_batches: List[Tuple[Literal[\"batch\", \"sequential\"], List[List[int]], int]] = []\n        current_batch: List[List[int]] = []\n        current_tokens = 0\n\n        for tokens in token_lists:\n            if not tokens:\n                continue\n            token_len = len(tokens)\n            is_oversized = token_len > prefill_chunk\n            if current_batch and (is_oversized or current_tokens + token_len > prefill_chunk):\n                sub_batches.append((\"batch\", current_batch, current_tokens))\n                current_batch, current_tokens = [], 0\n            if is_oversized:\n                sub_batches.append((\"sequential\", [tokens], token_len))\n            else:\n                current_batch.append(tokens)\n                current_tokens += 
token_len\n        if current_batch:\n            sub_batches.append((\"batch\", current_batch, current_tokens))\n\n        return sub_batches\n\n    def _batch_embed_decoder(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        token_lists: List[List[int]],\n        total_tokens: int,\n        max_seq_len: int,\n        prefill_chunk: int,\n        support_sliding: int,\n    ) -> Tuple[List[List[float]], int]:\n        \"\"\"Batch prefill: process all inputs in a single forward pass.\"\"\"\n        batch_size = len(token_lists)\n\n        # Create KV cache for the entire batch\n        kv_cache = self._create_kv_cache_func(\n            ShapeTuple([batch_size]),\n            ShapeTuple([max_seq_len]),\n            ShapeTuple([prefill_chunk]),\n            ShapeTuple([16]),\n            ShapeTuple([support_sliding]),\n        )\n\n        # Register all sequences\n        seq_ids = list(range(batch_size))\n        seq_lens = [len(t) for t in token_lists]\n        for sid in seq_ids:\n            self._kv_state_add_sequence(kv_cache, sid)\n\n        # Begin forward with all sequences at once\n        self._kv_state_begin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(seq_lens))\n\n        # Concatenate all tokens → embed → batch prefill\n        all_tokens = []\n        for tokens in token_lists:\n            all_tokens.extend(tokens)\n        token_ids = tvm.runtime.tensor(np.array(all_tokens, dtype=np.int32), device=self.device)\n        all_embed = self._embed_func(token_ids, self._params)\n        all_embed = self._nd_reshape(all_embed, ShapeTuple([1, total_tokens, all_embed.shape[-1]]))\n\n        hidden_states, _ = self._batch_prefill_to_hidden_func(all_embed, kv_cache, self._params)\n        # .numpy() copies to CPU, escaping TVM workspace buffer reuse across calls.\n        # (torch.from_dlpack is zero-copy and hits aliasing bugs on 2nd+ invocation.)\n        hidden_np = hidden_states.numpy()\n        
self._kv_state_end_forward(kv_cache)\n        for sid in seq_ids:\n            self._kv_state_remove_sequence(kv_cache, sid)\n\n        # Extract last token hidden state per sequence\n        embeddings: List[List[float]] = []\n        offset = 0\n        for tokens in token_lists:\n            last_pos = offset + len(tokens) - 1\n            pooled = hidden_np[0, last_pos, :].astype(np.float32)\n            if self.normalize:\n                norm = np.linalg.norm(pooled)\n                if norm > 1e-12:\n                    pooled = pooled / norm\n            embeddings.append(pooled.tolist())\n            offset += len(tokens)\n\n        return embeddings, total_tokens\n\n    def _sequential_embed_decoder(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        token_lists: List[List[int]],\n        total_tokens: int,\n        max_seq_len: int,\n        prefill_chunk: int,\n        support_sliding: int,\n    ) -> Tuple[List[List[float]], int]:\n        \"\"\"Sequential chunked prefill: process each input independently.\"\"\"\n        embeddings: List[List[float]] = []\n\n        for tokens in token_lists:\n            if len(tokens) == 0:\n                continue\n\n            # Create KV cache for this single sequence\n            kv_cache = self._create_kv_cache_func(\n                ShapeTuple([1]),\n                ShapeTuple([max_seq_len]),\n                ShapeTuple([prefill_chunk]),\n                ShapeTuple([16]),\n                ShapeTuple([support_sliding]),\n            )\n            self._kv_state_add_sequence(kv_cache, 0)\n\n            # Process tokens in chunks\n            hidden = None\n            for chunk_start in range(0, len(tokens), prefill_chunk):\n                chunk_end = min(chunk_start + prefill_chunk, len(tokens))\n                chunk_tokens = tokens[chunk_start:chunk_end]\n                chunk_len = len(chunk_tokens)\n\n                token_ids = tvm.runtime.tensor(\n                    
np.array(chunk_tokens, dtype=np.int32), device=self.device\n                )\n                chunk_embed = self._embed_func(token_ids, self._params)\n                chunk_embed = self._nd_reshape(\n                    chunk_embed, ShapeTuple([1, chunk_len, chunk_embed.shape[-1]])\n                )\n                self._kv_state_begin_forward(kv_cache, ShapeTuple([0]), ShapeTuple([chunk_len]))\n                hidden, kv_cache = self._prefill_to_hidden_func(chunk_embed, kv_cache, self._params)\n                # .numpy() copies to CPU, escaping TVM buffer aliasing.\n                hidden_np = hidden.numpy()\n                self._kv_state_end_forward(kv_cache)\n\n            self._kv_state_remove_sequence(kv_cache, 0)\n\n            pooled = hidden_np[0, -1, :] if hidden_np.ndim == 3 else hidden_np[-1, :]\n            pooled = pooled.astype(np.float32)\n            if self.normalize:\n                norm = np.linalg.norm(pooled)\n                if norm > 1e-12:\n                    pooled = pooled / norm\n            embeddings.append(pooled.tolist())\n\n        return embeddings, total_tokens\n\n    def terminate(self) -> None:\n        \"\"\"Terminate the engine and clean up the thread pool.\"\"\"\n        if getattr(self, \"_terminated\", True):\n            return\n        self._terminated = True\n        self._executor.shutdown(wait=False)\n\n    def __del__(self):\n        self.terminate()\n"
  },
  {
    "path": "python/mlc_llm/serve/engine.py",
    "content": "\"\"\"The MLC LLM Serving Engine.\"\"\"\n\n# pylint: disable=too-many-lines\n\nimport asyncio\nimport queue\nimport sys\nimport weakref\nfrom typing import (\n    Any,\n    AsyncGenerator,\n    Dict,\n    Iterator,\n    List,\n    Literal,\n    Optional,\n    Tuple,\n    Union,\n    overload,\n)\n\nfrom tvm.runtime import Device\n\nfrom mlc_llm.protocol import debug_protocol, openai_api_protocol\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import data, engine_utils\nfrom mlc_llm.serve.config import EngineConfig\nfrom mlc_llm.support import logging\nfrom mlc_llm.tokenizers import TextStreamer\n\nfrom . import engine_base\n\nlogger = logging.getLogger(__name__)\n\n\n# Note: we define both AsyncChat and Chat for Python type analysis.\nclass AsyncChat:  # pylint: disable=too-few-public-methods\n    \"\"\"The proxy class to direct to async chat completions.\"\"\"\n\n    def __init__(self, engine: weakref.ReferenceType) -> None:\n        assert isinstance(engine(), AsyncMLCEngine)\n        self.completions = AsyncChatCompletion(engine)\n\n\nclass Chat:  # pylint: disable=too-few-public-methods\n    \"\"\"The proxy class to direct to chat completions.\"\"\"\n\n    def __init__(self, engine: weakref.ReferenceType) -> None:\n        assert isinstance(engine(), MLCEngine)\n        self.completions = ChatCompletion(engine)\n\n\nclass AsyncChatCompletion:  # pylint: disable=too-few-public-methods\n    \"\"\"The proxy class to direct to async chat completions.\"\"\"\n\n    if sys.version_info >= (3, 9):\n        engine: weakref.ReferenceType[\"AsyncMLCEngine\"]\n    else:\n        engine: weakref.ReferenceType\n\n    def __init__(self, engine: weakref.ReferenceType) -> None:\n        self.engine = engine\n\n    @overload\n    async def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        stream: Literal[True],\n        model: 
Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]:\n        \"\"\"Asynchronous streaming chat completion interface with OpenAI API compatibility.\n        The method is a coroutine that streams ChatCompletionStreamResponse\n        that conforms to OpenAI API one at a time via yield.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Yields\n        ------\n        stream_response : ChatCompletionStreamResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/chat/streaming for specification.\n\n        Raises\n        ------\n        
e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    @overload\n    async def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: Literal[False] = False,\n        stream_options: Literal[None] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> openai_api_protocol.ChatCompletionResponse:\n        \"\"\"Asynchronous non-streaming chat completion interface with OpenAI API compatibility.\n        The method is a coroutine that returns a single ChatCompletionResponse\n        that conforms to OpenAI API.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Returns\n        -------\n        response : 
ChatCompletionResponse\n            The chat completion response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/chat/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    async def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any],\n        openai_api_protocol.ChatCompletionResponse,\n    ]:\n        \"\"\"Asynchronous chat completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: 
Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        return await self.engine()._chat_completion(  # pylint: disable=protected-access\n            messages=messages,\n            model=model,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            logprobs=logprobs,\n            top_logprobs=top_logprobs,\n            logit_bias=logit_bias,\n            max_tokens=max_tokens,\n            n=n,\n            seed=seed,\n            stop=stop,\n            stream=stream,\n            stream_options=(\n                openai_api_protocol.StreamOptions.model_validate(stream_options)\n                if stream_options is not None\n                else None\n            ),\n            temperature=temperature,\n            top_p=top_p,\n            tools=tools,\n            tool_choice=tool_choice,\n            user=user,\n            response_format=response_format,\n            request_id=request_id,\n            debug_config=(extra_body.get(\"debug_config\", None) if extra_body is not None else None),\n        )\n\n\nclass ChatCompletion:  # pylint: disable=too-few-public-methods\n    \"\"\"The proxy class to direct to chat completions.\"\"\"\n\n    if sys.version_info >= (3, 9):\n        engine: weakref.ReferenceType[\"MLCEngine\"]\n    else:\n        engine: weakref.ReferenceType\n\n    def __init__(self, engine: weakref.ReferenceType) -> None:\n        self.engine = engine\n\n    @overload\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        stream: Literal[True],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n 
       presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:\n        \"\"\"Synchronous streaming chat completion interface with OpenAI API compatibility.\n        The method streams back ChatCompletionStreamResponse that conforms to\n        OpenAI API one at a time via yield.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Yields\n        ------\n        stream_response : ChatCompletionStreamResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/chat/streaming for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        
\"\"\"\n\n    @overload\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: Literal[False] = False,\n        stream_options: Literal[None] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> openai_api_protocol.ChatCompletionResponse:\n        \"\"\"Synchronous non-streaming chat completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Returns\n        -------\n        response : ChatCompletionResponse\n            The chat completion response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/chat/object for specification.\n\n 
       Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        Iterator[openai_api_protocol.ChatCompletionStreamResponse],\n        openai_api_protocol.ChatCompletionResponse,\n    ]:\n        \"\"\"Synchronous chat completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is 
raised when the request is invalid.\n        \"\"\"\n        return self.engine()._chat_completion(  # pylint: disable=protected-access\n            messages=messages,\n            model=model,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            logprobs=logprobs,\n            top_logprobs=top_logprobs,\n            logit_bias=logit_bias,\n            max_tokens=max_tokens,\n            n=n,\n            seed=seed,\n            stop=stop,\n            stream=stream,\n            stream_options=(\n                openai_api_protocol.StreamOptions.model_validate(stream_options)\n                if stream_options is not None\n                else None\n            ),\n            temperature=temperature,\n            top_p=top_p,\n            tools=tools,\n            tool_choice=tool_choice,\n            user=user,\n            response_format=response_format,\n            request_id=request_id,\n            debug_config=(extra_body.get(\"debug_config\", None) if extra_body is not None else None),\n        )\n\n\nclass AsyncCompletion:  # pylint: disable=too-few-public-methods\n    \"\"\"The proxy class to direct to async completions.\"\"\"\n\n    if sys.version_info >= (3, 9):\n        engine: weakref.ReferenceType[\"AsyncMLCEngine\"]\n    else:\n        engine: weakref.ReferenceType\n\n    def __init__(self, engine: weakref.ReferenceType) -> None:\n        self.engine = engine\n\n    @overload\n    async def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        stream: Literal[True],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        
n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream_options: Optional[Dict[str, Any]] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"Asynchronous streaming completion interface with OpenAI API compatibility.\n        The method is a coroutine that streams CompletionResponse\n        that conforms to OpenAI API one at a time via yield.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Yields\n        ------\n        stream_response : CompletionResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/completions/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    @overload\n    async def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n        frequency_penalty: Optional[float] = None,\n        
presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: Literal[False] = False,\n        stream_options: Literal[None] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> openai_api_protocol.CompletionResponse:\n        \"\"\"Asynchronous non-streaming completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Returns\n        -------\n        response : CompletionResponse\n            The completion response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/completions/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    async def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n  
      frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        AsyncGenerator[openai_api_protocol.CompletionResponse, Any],\n        openai_api_protocol.CompletionResponse,\n    ]:\n        \"\"\"Asynchronous completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        return await self.engine()._completion(  # pylint: disable=protected-access\n            model=model,\n            prompt=prompt,\n            best_of=best_of,\n            echo=echo,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            logprobs=logprobs,\n            logit_bias=logit_bias,\n            max_tokens=max_tokens,\n   
         n=n,\n            seed=seed,\n            stop=stop,\n            stream=stream,\n            stream_options=(\n                openai_api_protocol.StreamOptions.model_validate(stream_options)\n                if stream_options is not None\n                else None\n            ),\n            suffix=suffix,\n            temperature=temperature,\n            top_p=top_p,\n            user=user,\n            response_format=response_format,\n            request_id=request_id,\n            debug_config=(extra_body.get(\"debug_config\", None) if extra_body is not None else None),\n        )\n\n\nclass Completion:  # pylint: disable=too-few-public-methods\n    \"\"\"The proxy class to direct to completions.\"\"\"\n\n    if sys.version_info >= (3, 9):\n        engine: weakref.ReferenceType[\"MLCEngine\"]\n    else:\n        engine: weakref.ReferenceType\n\n    def __init__(self, engine: weakref.ReferenceType) -> None:\n        self.engine = engine\n\n    @overload\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        stream: Literal[True],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream_options: Optional[Dict[str, Any]] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> 
Iterator[openai_api_protocol.CompletionResponse]:\n        \"\"\"Synchronous streaming completion interface with OpenAI API compatibility.\n        The method streams back CompletionResponse that conforms to\n        OpenAI API one at a time via yield.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Yields\n        ------\n        stream_response : CompletionResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/completions/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    @overload\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: Literal[False] = False,\n        stream_options: Literal[None] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: 
Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> openai_api_protocol.CompletionResponse:\n        \"\"\"Synchronous non-streaming completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Returns\n        -------\n        response : CompletionResponse\n            The completion response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/completions/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n\n    def create(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n 
       top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        extra_body: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        Iterator[openai_api_protocol.CompletionResponse],\n        openai_api_protocol.CompletionResponse,\n    ]:\n        \"\"\"Synchronous completion interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        extra_body: Optional[Dict[str, Any]] = None,\n            Extra body options to pass to the request.\n            Can be used to pass debug config as extra_body[\"debug_config\"]\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        return self.engine()._completion(  # pylint: disable=protected-access\n            model=model,\n            prompt=prompt,\n            best_of=best_of,\n            echo=echo,\n            frequency_penalty=frequency_penalty,\n            presence_penalty=presence_penalty,\n            logprobs=logprobs,\n            logit_bias=logit_bias,\n            max_tokens=max_tokens,\n            n=n,\n            seed=seed,\n            stop=stop,\n            stream=stream,\n            stream_options=(\n                openai_api_protocol.StreamOptions.model_validate(stream_options)\n                if stream_options is not None\n                else None\n            ),\n            suffix=suffix,\n            temperature=temperature,\n            top_p=top_p,\n            user=user,\n            response_format=response_format,\n            request_id=request_id,\n            
debug_config=(extra_body.get(\"debug_config\", None) if extra_body is not None else None),\n        )\n\n\nclass AsyncMLCEngine(engine_base.MLCEngineBase):\n    \"\"\"The AsyncMLCEngine in MLC LLM that provides the asynchronous\n    interfaces with regard to OpenAI API.\n\n    Parameters\n    ----------\n    model : str\n        A path to ``mlc-chat-config.json``, or an MLC model directory that contains\n        ``mlc-chat-config.json``.\n        It can also be a link to a HF repository pointing to an MLC compiled model.\n\n    device : Union[str, Device]\n        The device used to deploy the model such as \"cuda\" or \"cuda:0\".\n        Will default to \"auto\" and detect from locally available GPUs if not specified.\n\n    model_lib : Optional[str]\n        The full path to the model library file to use (e.g. a ``.so`` file).\n        If unspecified, we will use the provided ``model`` to search over possible paths.\n        If the model lib is not found, it will be compiled in a JIT manner.\n\n    mode : Literal[\"local\", \"interactive\", \"server\"]\n        The engine mode in MLC LLM.\n        We provide three preset modes: \"local\", \"interactive\" and \"server\".\n        The default mode is \"local\".\n        The choice of mode decides the values of \"max_num_sequence\", \"max_total_sequence_length\"\n        and \"prefill_chunk_size\" when they are not explicitly specified.\n        1. Mode \"local\" refers to the local server deployment, which has low\n        request concurrency. So the max batch size will be set to 4, and max\n        total sequence length and prefill chunk size are set to the context\n        window size (or sliding window size) of the model.\n        2. Mode \"interactive\" refers to the interactive use of the server, which\n        has at most 1 concurrent request. 
So the max batch size will be set to 1,\n        and max total sequence length and prefill chunk size are set to the context\n        window size (or sliding window size) of the model.\n        3. Mode \"server\" refers to the large server use case, which may handle\n        many concurrent requests and wants to use as much GPU memory as possible.\n        In this mode, we will automatically infer the largest possible max batch\n        size and max total sequence length.\n\n        You can manually specify arguments \"max_num_sequence\", \"max_total_sequence_length\" and\n        \"prefill_chunk_size\" to override the automatically inferred values.\n\n    engine_config : Optional[EngineConfig]\n        Additional configurable arguments of the MLC engine.\n        See class \"EngineConfig\" for more details.\n\n    enable_tracing : bool\n        A boolean indicating whether to enable event logging for requests.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        model: str,\n        device: Union[str, Device] = \"auto\",\n        *,\n        model_lib: Optional[str] = None,\n        mode: Literal[\"local\", \"interactive\", \"server\"] = \"local\",\n        engine_config: Optional[EngineConfig] = None,\n        enable_tracing: bool = False,\n    ) -> None:\n        super().__init__(\n            \"async\",\n            model=model,\n            device=device,\n            model_lib=model_lib,\n            mode=mode,\n            engine_config=engine_config,\n            enable_tracing=enable_tracing,\n        )\n        self.chat = AsyncChat(weakref.ref(self))\n        self.completions = AsyncCompletion(weakref.ref(self))\n\n    async def abort(self, request_id: str) -> None:\n        \"\"\"Generation abortion interface.\n\n        Parameters\n        ----------\n        request_id : str\n            The id of the request to abort.\n        \"\"\"\n        self._abort(request_id)\n\n    async def metrics(self) -> 
engine_base.EngineMetrics:\n        \"\"\"Get engine metrics.\n\n        Returns\n        -------\n        metrics : EngineMetrics\n            The engine metrics.\n        \"\"\"\n        # pylint: disable=protected-access\n        return await engine_base._async_query_engine_metrics(self)\n\n    async def _chat_completion(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        debug_config: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any],\n        openai_api_protocol.ChatCompletionResponse,\n    ]:\n        \"\"\"Asynchronous chat completion internal interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.
\n\n        debug_config: Optional[Dict[str, Any]] = None,\n            Debug config options to pass to the request.\n            The public ``create`` APIs pass extra_body[\"debug_config\"] here.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        if request_id is None:\n            request_id = f\"chatcmpl-{engine_utils.random_uuid()}\"\n\n        chatcmpl_generator = self._handle_chat_completion(\n            openai_api_protocol.ChatCompletionRequest(\n                messages=[\n                    openai_api_protocol.ChatCompletionMessage.model_validate(message)\n                    for message in messages\n                ],\n                model=model,\n                frequency_penalty=frequency_penalty,\n                presence_penalty=presence_penalty,\n                logprobs=logprobs,\n                top_logprobs=top_logprobs,\n                logit_bias=logit_bias,\n                max_tokens=max_tokens,\n                n=n,\n                seed=seed,\n                stop=stop,\n                stream=stream,\n                stream_options=(\n                    openai_api_protocol.StreamOptions.model_validate(stream_options)\n                    if stream_options is not None\n                    else None\n                ),\n                temperature=temperature,\n                top_p=top_p,\n                tools=(\n                    [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]\n                    if tools is not None\n                    else None\n                ),\n                tool_choice=tool_choice,\n                user=user,\n                response_format=(\n                    openai_api_protocol.RequestResponseFormat.model_validate(response_format)\n                    if response_format is not None\n                    else None\n                ),\n                debug_config=(\n                    
debug_protocol.DebugConfig.model_validate(debug_config)\n                    if debug_config is not None\n                    else None\n                ),\n            ),\n            request_id=request_id,\n            request_final_usage_include_extra=True,\n        )\n        if stream:\n            # Stream response.\n            return chatcmpl_generator\n        # Normal response.\n        output_texts = [\"\" for _ in range(n)]\n        finish_reasons: List[Optional[str]] = [None for _ in range(n)]\n        logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = (\n            [[] for _ in range(n)] if logprobs else None\n        )\n        request_final_usage = None\n        try:\n            async for response in chatcmpl_generator:\n                # when usage is not None this is the last chunk\n                if response.usage is not None:\n                    request_final_usage = response.usage\n                    continue\n                for choice in response.choices:\n                    assert isinstance(choice.delta.content, str)\n                    output_texts[choice.index] += choice.delta.content\n                    if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                        finish_reasons[choice.index] = choice.finish_reason\n                    if choice.logprobs is not None:\n                        assert logprob_results is not None\n                        logprob_results[  # pylint: disable=unsupported-assignment-operation\n                            choice.index\n                        ] += choice.logprobs.content\n        except asyncio.CancelledError:  # pylint: disable=try-except-raise\n            # for cancelled error, we can simply pass it through\n            raise\n        except Exception as err:  # pylint: disable=broad-exception-caught\n            logger.error(\"Error in chat completion with request ID %s: %s\", request_id, err)\n            raise\n\n    
    assert all(finish_reason is not None for finish_reason in finish_reasons)\n        use_function_calling, tool_calls_list = engine_base.process_function_call_output(\n            output_texts, finish_reasons\n        )\n        return engine_base.wrap_chat_completion_response(\n            request_id=request_id,\n            model=model,\n            output_texts=output_texts,\n            finish_reasons=finish_reasons,\n            tool_calls_list=tool_calls_list,\n            logprob_results=logprob_results,\n            use_function_calling=use_function_calling,\n            usage=request_final_usage,\n        )\n\n    async def _completion(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        debug_config: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        AsyncGenerator[openai_api_protocol.CompletionResponse, Any],\n        openai_api_protocol.CompletionResponse,\n    ]:\n        \"\"\"Asynchronous completion internal interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        
----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        debug_config: Optional[Dict[str, Any]] = None,\n            Extra debug options to pass to the request.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        if request_id is None:\n            request_id = f\"cmpl-{engine_utils.random_uuid()}\"\n        cmpl_generator = self._handle_completion(\n            openai_api_protocol.CompletionRequest(\n                model=model,\n                prompt=prompt,\n                best_of=best_of,\n                echo=echo,\n                frequency_penalty=frequency_penalty,\n                presence_penalty=presence_penalty,\n                logprobs=logprobs,\n                logit_bias=logit_bias,\n                max_tokens=max_tokens,\n                n=n,\n                seed=seed,\n                stop=stop,\n                stream=stream,\n                stream_options=(\n                    openai_api_protocol.StreamOptions.model_validate(stream_options)\n                    if stream_options is not None\n                    else None\n                ),\n                suffix=suffix,\n                temperature=temperature,\n                top_p=top_p,\n                user=user,\n                response_format=(\n                    openai_api_protocol.RequestResponseFormat.model_validate(response_format)\n                    if response_format is not None\n                    else None\n                ),\n                debug_config=(\n                    debug_protocol.DebugConfig.model_validate(debug_config)\n                    if debug_config is not None\n                    else None\n                ),\n            ),\n            request_id=request_id,\n            request_final_usage_include_extra=True,\n        )\n        if 
stream:\n            # Stream response.\n            return cmpl_generator\n        # Normal response.\n        request_final_usage = None\n        output_texts = [\"\"] * n\n        finish_reasons: List[Optional[str]] = [None] * n\n        logprob_results: List[Optional[openai_api_protocol.CompletionLogProbs]] = [None] * n\n\n        async for response in cmpl_generator:\n            # when usage is not None, this is the final chunk\n            if response.usage is not None:\n                request_final_usage = response.usage\n                continue\n            for choice in response.choices:\n                output_texts[choice.index] += choice.text\n                if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                    finish_reasons[choice.index] = choice.finish_reason\n                if choice.logprobs is not None:\n                    logprob_results[choice.index] = choice.logprobs\n\n        assert all(finish_reason is not None for finish_reason in finish_reasons)\n\n        return engine_base.wrap_completion_response(\n            request_id=request_id,\n            model=model,\n            output_texts=output_texts,\n            finish_reasons=finish_reasons,\n            logprob_results=logprob_results,\n            usage=request_final_usage,\n        )\n\n    async def _handle_chat_completion(\n        self,\n        request: openai_api_protocol.ChatCompletionRequest,\n        request_id: str,\n        request_final_usage_include_extra: bool,\n    ) -> AsyncGenerator[openai_api_protocol.ChatCompletionStreamResponse, Any]:\n        \"\"\"The implementation of asynchronous ChatCompletionRequest handling.\n\n        Yields\n        ------\n        stream_response : ChatCompletionStreamResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/chat/streaming for specification.\n\n        Raises\n        
------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        (\n            prompts,\n            generation_cfg,\n            use_function_calling,\n            prompt_length,\n        ) = engine_base.process_chat_completion_request(\n            request,\n            request_id,\n            self.state,\n            self.model_config_dicts[0],\n            self.tokenizer.encode,\n            self.max_input_sequence_length,\n            self.conv_template.model_copy(deep=True),\n        )\n        # prompt length is not used\n        _ = prompt_length\n        finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]\n        self.state.record_event(request_id, event=\"invoke generate\")\n        try:\n            async for delta_outputs in self._generate(\n                prompts,  # type: ignore[arg-type]\n                generation_cfg,\n                request_id,  # type: ignore\n            ):\n                response = engine_base.process_chat_completion_stream_output(\n                    delta_outputs,\n                    request,\n                    request_id,\n                    self.state,\n                    use_function_calling,\n                    finish_reasons,\n                )\n\n                if response is not None:\n                    if response.usage is not None:\n                        if not request_final_usage_include_extra:\n                            response.usage.extra = None\n                    yield response\n            self.state.record_event(request_id, event=\"finish\")\n        except asyncio.CancelledError:  # pylint: disable=try-except-raise\n            # for cancelled error, we can simply pass it through\n            raise\n        except Exception as err:  # pylint: disable=broad-exception-caught\n            logger.error(\"Error in _handle_chat_completion for request %s: %s\", request_id, err)\n            raise\n\n    
async def _handle_completion(\n        self,\n        request: openai_api_protocol.CompletionRequest,\n        request_id: str,\n        request_final_usage_include_extra: bool,\n    ) -> AsyncGenerator[openai_api_protocol.CompletionResponse, Any]:\n        \"\"\"The implementation for asynchronous CompletionRequest handling.\n\n        Yields\n        ------\n        stream_response : CompletionResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/completions/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        (\n            prompt,\n            generation_cfg,\n            prompt_length,\n            echo_response,\n        ) = engine_base.process_completion_request(\n            request,\n            request_id,\n            self.state,\n            self.tokenizer,\n            self.max_input_sequence_length,\n            self.conv_template.model_copy(deep=True),\n        )\n        _ = prompt_length\n        if echo_response is not None:\n            yield echo_response\n\n        finish_reasons: List[Optional[str]] = [None] * generation_cfg.n\n        self.state.record_event(request_id, event=\"invoke generate\")\n        try:\n            async for delta_outputs in self._generate(\n                prompt,\n                generation_cfg,\n                request_id,  # type: ignore\n            ):\n                response = engine_base.process_completion_stream_output(\n                    delta_outputs,\n                    request,\n                    request_id,\n                    self.state,\n                    finish_reasons,\n                )\n\n                if response is not None:\n                    if response.usage is not None:\n                        if not 
request_final_usage_include_extra:\n                            response.usage.extra = None\n                    yield response\n\n            suffix_response = engine_base.create_completion_suffix_response(\n                request, request_id, finish_reasons\n            )\n            if suffix_response is not None:\n                yield suffix_response\n            self.state.record_event(request_id, event=\"finish\")\n        except asyncio.CancelledError:  # pylint: disable=try-except-raise\n            # for cancelled error, we can simply pass it through\n            raise\n        except Exception as err:  # pylint: disable=broad-exception-caught\n            logger.error(\"Error in _handle_completion for request %s: %s\", request_id, err)\n            raise\n\n    async def _generate(\n        self,\n        prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]],\n        generation_config: GenerationConfig,\n        request_id: str,\n    ) -> AsyncGenerator[List[engine_base.CallbackStreamOutput], Any]:\n        \"\"\"Internal asynchronous text generation interface of AsyncMLCEngine.\n        The method is a coroutine that streams a list of CallbackStreamOutput\n        at a time via yield. 
The returned list length is the number of\n        parallel generations specified by `generation_config.n`.\n\n        Parameters\n        ----------\n        prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]]\n            The input prompt in forms of text strings, lists of token ids or data.\n\n        generation_config : GenerationConfig\n            The generation config of the request.\n\n        request_id : str\n            The unique identifier (in string) of this generation request.\n\n        Yields\n        ------\n        request_output : List[engine_base.CallbackStreamOutput]\n            The delta generated outputs in a list.\n            The number of list elements equals `generation_config.n`,\n            and each element corresponds to the delta output of a parallel\n            generation.\n        \"\"\"\n        if self._terminated:\n            raise ValueError(\"The AsyncThreadedEngine has terminated.\")\n        self.state.async_lazy_init_event_loop()\n\n        # Create the request with the given id, input data, generation\n        # config and the created callback.\n        input_data = engine_utils.convert_prompts_to_data(prompt)\n        request = self._ffi[\"create_request\"](\n            request_id, input_data, generation_config.model_dump_json(by_alias=True)\n        )\n\n        # Create the unique async request stream of the request.\n        stream = engine_base.AsyncRequestStream()\n        if request_id in self.state.async_streamers:\n            # Report error in the stream if the request id already exists.\n            stream.push(\n                RuntimeError(\n                    f'The request id \"{request_id}\" already exists. 
'\n                    'Please make sure the request id is unique.'\n                )\n            )\n        else:\n            # Record the stream in the tracker\n            self.state.async_streamers[request_id] = (\n                stream,\n                [TextStreamer(self.tokenizer) for _ in range(generation_config.n)],\n            )\n            self._ffi[\"add_request\"](request)\n\n        def abort_request():\n            \"\"\"clean up\"\"\"\n            self._abort(request_id)\n            logger.info(\"request %s cancelled\", request_id)\n\n        with engine_utils.ErrorCleanupScope(abort_request):\n            # Iterate the stream asynchronously and yield the output.\n            try:\n                async for request_output in stream:\n                    yield request_output\n            except asyncio.CancelledError:  # pylint: disable=try-except-raise\n                # for cancelled error, we can simply pass it through\n                raise\n            except Exception as exception:  # pylint: disable=broad-exception-caught\n                logger.error(\"Exception in _generate for request %s: %s\", request_id, exception)\n                raise\n\n    def _abort(self, request_id: str):\n        \"\"\"Internal implementation of request abortion.\"\"\"\n        self.state.async_streamers.pop(request_id, None)\n        self._ffi[\"abort_request\"](request_id)\n\n\nclass MLCEngine(engine_base.MLCEngineBase):\n    \"\"\"The MLCEngine in MLC LLM that provides the synchronous\n    interfaces with regard to OpenAI API.\n\n    Parameters\n    ----------\n    model : str\n        A path to ``mlc-chat-config.json``, or an MLC model directory that contains\n        `mlc-chat-config.json`.\n        It can also be a link to a HF repository pointing to an MLC compiled model.\n\n    device : Union[str, Device]\n        The device used to deploy the model such as \"cuda\" or \"cuda:0\".\n        Will default to \"auto\" and detect from local available 
GPUs if not specified.\n\n    model_lib : Optional[str]\n        The full path to the model library file to use (e.g. a ``.so`` file).\n        If unspecified, we will use the provided ``model`` to search over possible paths.\n        If the model lib is not found, it will be compiled in a JIT manner.\n\n    mode : Literal[\"local\", \"interactive\", \"server\"]\n        The engine mode in MLC LLM.\n        We provide three preset modes: \"local\", \"interactive\" and \"server\".\n        The default mode is \"local\".\n        The choice of mode decides the values of \"max_num_sequence\", \"max_total_sequence_length\"\n        and \"prefill_chunk_size\" when they are not explicitly specified.\n        1. Mode \"local\" refers to the local server deployment which has low\n        request concurrency. So the max batch size will be set to 4, and max\n        total sequence length and prefill chunk size are set to the context\n        window size (or sliding window size) of the model.\n        2. Mode \"interactive\" refers to the interactive use of server, which\n        has at most 1 concurrent request. So the max batch size will be set to 1,\n        and max total sequence length and prefill chunk size are set to the context\n        window size (or sliding window size) of the model.\n        3. 
Mode \"server\" refers to the large server use case which may handle\n        many concurrent requests and wants to use as much GPU memory as possible.\n        In this mode, we will automatically infer the largest possible max batch\n        size and max total sequence length.\n\n        You can manually specify arguments \"max_num_sequence\", \"max_total_sequence_length\" and\n        \"prefill_chunk_size\" to override the automatically inferred values.\n\n    engine_config : Optional[EngineConfig]\n        Additional configurable arguments of MLC engine.\n        See class \"EngineConfig\" for more details.\n\n    enable_tracing : bool\n        A boolean indicating whether to enable event logging for requests.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        model: str,\n        device: Union[str, Device] = \"auto\",\n        *,\n        model_lib: Optional[str] = None,\n        mode: Literal[\"local\", \"interactive\", \"server\"] = \"local\",\n        engine_config: Optional[EngineConfig] = None,\n        enable_tracing: bool = False,\n    ) -> None:\n        super().__init__(\n            \"sync\",\n            model=model,\n            device=device,\n            model_lib=model_lib,\n            mode=mode,\n            engine_config=engine_config,\n            enable_tracing=enable_tracing,\n        )\n        self.chat = Chat(weakref.ref(self))\n        self.completions = Completion(weakref.ref(self))\n\n    def abort(self, request_id: str) -> None:\n        \"\"\"Generation abortion interface.\n\n        Parameters\n        ----------\n        request_id : str\n            The id of the request to abort.\n        \"\"\"\n        self._ffi[\"abort_request\"](request_id)\n\n    def metrics(self) -> engine_base.EngineMetrics:\n        \"\"\"Get engine metrics\n\n        Returns\n        -------\n        metrics : EngineMetrics\n            The engine metrics\n        \"\"\"\n        # pylint: 
disable=protected-access\n        return engine_base._query_engine_metrics(self)\n\n    def _chat_completion(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        messages: List[Dict[str, Any]],\n        model: Optional[str] = None,\n        frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: bool = False,\n        top_logprobs: int = 0,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        tools: Optional[List[Dict[str, Any]]] = None,\n        tool_choice: Optional[Union[Literal[\"none\", \"auto\"], Dict]] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        debug_config: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        Iterator[openai_api_protocol.ChatCompletionStreamResponse],\n        openai_api_protocol.ChatCompletionResponse,\n    ]:\n        \"\"\"Synchronous chat completion internal interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/chat/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        debug_config: Optional[Dict[str, Any]] = None,\n            Extra debug options to pass to the request.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        if request_id is None:\n            request_id = 
f\"chatcmpl-{engine_utils.random_uuid()}\"\n\n        chatcmpl_generator = self._handle_chat_completion(\n            openai_api_protocol.ChatCompletionRequest(\n                messages=[\n                    openai_api_protocol.ChatCompletionMessage.model_validate(message)\n                    for message in messages\n                ],\n                model=model,\n                frequency_penalty=frequency_penalty,\n                presence_penalty=presence_penalty,\n                logprobs=logprobs,\n                top_logprobs=top_logprobs,\n                logit_bias=logit_bias,\n                max_tokens=max_tokens,\n                n=n,\n                seed=seed,\n                stop=stop,\n                stream=stream,\n                stream_options=(\n                    openai_api_protocol.StreamOptions.model_validate(stream_options)\n                    if stream_options is not None\n                    else None\n                ),\n                temperature=temperature,\n                top_p=top_p,\n                tools=(\n                    [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]\n                    if tools is not None\n                    else None\n                ),\n                tool_choice=tool_choice,\n                user=user,\n                response_format=(\n                    openai_api_protocol.RequestResponseFormat.model_validate(response_format)\n                    if response_format is not None\n                    else None\n                ),\n                debug_config=(\n                    debug_protocol.DebugConfig.model_validate(debug_config)\n                    if debug_config is not None\n                    else None\n                ),\n            ),\n            request_id=request_id,\n        )\n        if stream:\n            # Stream response.\n            return chatcmpl_generator\n        # Normal response.\n        request_final_usage = None\n        
output_texts = [\"\" for _ in range(n)]\n        finish_reasons: List[Optional[str]] = [None for _ in range(n)]\n        logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]] = (\n            [[] for _ in range(n)] if logprobs else None\n        )\n        for response in chatcmpl_generator:\n            # if usage is not None, this is the last chunk\n            if response.usage is not None:\n                request_final_usage = response.usage\n                continue\n            for choice in response.choices:\n                assert isinstance(choice.delta.content, str)\n                output_texts[choice.index] += choice.delta.content\n                if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                    finish_reasons[choice.index] = choice.finish_reason\n                if choice.logprobs is not None:\n                    assert logprob_results is not None\n                    logprob_results[  # pylint: disable=unsupported-assignment-operation\n                        choice.index\n                    ] += choice.logprobs.content\n\n        assert all(finish_reason is not None for finish_reason in finish_reasons)\n        use_function_calling, tool_calls_list = engine_base.process_function_call_output(\n            output_texts, finish_reasons\n        )\n        return engine_base.wrap_chat_completion_response(\n            request_id=request_id,\n            model=model,\n            output_texts=output_texts,\n            finish_reasons=finish_reasons,\n            tool_calls_list=tool_calls_list,\n            logprob_results=logprob_results,\n            use_function_calling=use_function_calling,\n            usage=request_final_usage,\n        )\n\n    def _completion(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        *,\n        prompt: Union[str, List[int]],\n        model: Optional[str] = None,\n        best_of: int = 1,\n        echo: bool = False,\n   
     frequency_penalty: Optional[float] = None,\n        presence_penalty: Optional[float] = None,\n        logprobs: Optional[int] = None,\n        logit_bias: Optional[Dict[int, float]] = None,\n        max_tokens: Optional[int] = None,\n        n: int = 1,\n        seed: Optional[int] = None,\n        stop: Optional[Union[str, List[str]]] = None,\n        stream: bool = False,\n        stream_options: Optional[Dict[str, Any]] = None,\n        suffix: Optional[str] = None,\n        temperature: Optional[float] = None,\n        top_p: Optional[float] = None,\n        user: Optional[str] = None,\n        response_format: Optional[Dict[str, Any]] = None,\n        request_id: Optional[str] = None,\n        debug_config: Optional[Dict[str, Any]] = None,\n    ) -> Union[\n        Iterator[openai_api_protocol.CompletionResponse],\n        openai_api_protocol.CompletionResponse,\n    ]:\n        \"\"\"Synchronous completion internal interface with OpenAI API compatibility.\n\n        See https://platform.openai.com/docs/api-reference/completions/create for specification.\n\n        Parameters\n        ----------\n        request_id : Optional[str]\n            The optional request id.\n            A random one will be generated if it is not given.\n\n        debug_config: Optional[Dict[str, Any]] = None,\n            Extra debug options to pass to the request.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        if request_id is None:\n            request_id = f\"cmpl-{engine_utils.random_uuid()}\"\n\n        cmpl_generator = self._handle_completion(\n            openai_api_protocol.CompletionRequest(\n                model=model,\n                prompt=prompt,\n                best_of=best_of,\n                echo=echo,\n                frequency_penalty=frequency_penalty,\n                presence_penalty=presence_penalty,\n                logprobs=logprobs,\n       
         logit_bias=logit_bias,\n                max_tokens=max_tokens,\n                n=n,\n                seed=seed,\n                stop=stop,\n                stream=stream,\n                stream_options=(\n                    openai_api_protocol.StreamOptions.model_validate(stream_options)\n                    if stream_options is not None\n                    else None\n                ),\n                suffix=suffix,\n                temperature=temperature,\n                top_p=top_p,\n                user=user,\n                response_format=(\n                    openai_api_protocol.RequestResponseFormat.model_validate(response_format)\n                    if response_format is not None\n                    else None\n                ),\n                debug_config=(\n                    debug_protocol.DebugConfig.model_validate(debug_config)\n                    if debug_config is not None\n                    else None\n                ),\n            ),\n            request_id=request_id,\n        )\n        if stream:\n            # Stream response.\n            return cmpl_generator\n        # Normal response.\n        request_final_usage = None\n        output_texts = [\"\"] * n\n        finish_reasons: List[Optional[str]] = [None] * n\n        logprob_results: List[Optional[openai_api_protocol.CompletionLogProbs]] = [None] * n\n\n        for response in cmpl_generator:\n            # this is the final chunk\n            if response.usage is not None:\n                request_final_usage = response.usage\n                continue\n            for choice in response.choices:\n                output_texts[choice.index] += choice.text\n                if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                    finish_reasons[choice.index] = choice.finish_reason\n                if choice.logprobs is not None:\n                    logprob_results[choice.index] = choice.logprobs\n\n        assert 
all(finish_reason is not None for finish_reason in finish_reasons)\n        return engine_base.wrap_completion_response(\n            request_id=request_id,\n            model=model,\n            output_texts=output_texts,\n            finish_reasons=finish_reasons,\n            logprob_results=logprob_results,\n            usage=request_final_usage,\n        )\n\n    def _handle_chat_completion(\n        self, request: openai_api_protocol.ChatCompletionRequest, request_id: str\n    ) -> Iterator[openai_api_protocol.ChatCompletionStreamResponse]:\n        \"\"\"The implementation for synchronous ChatCompletionRequest handling.\n\n        Yields\n        ------\n        stream_response : ChatCompletionStreamResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/chat/streaming for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        (\n            prompts,\n            generation_cfg,\n            use_function_calling,\n            prompt_length,\n        ) = engine_base.process_chat_completion_request(\n            request,\n            request_id,\n            self.state,\n            self.model_config_dicts[0],\n            self.tokenizer.encode,\n            self.max_input_sequence_length,\n            self.conv_template.model_copy(deep=True),\n        )\n        _ = prompt_length\n\n        finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]\n        self.state.record_event(request_id, event=\"invoke generate\")\n        for delta_outputs in self._generate(prompts, generation_cfg, request_id):  # type: ignore\n            response = engine_base.process_chat_completion_stream_output(\n                delta_outputs,\n                request,\n                request_id,\n                self.state,\n    
            use_function_calling,\n                finish_reasons,\n            )\n            if response is not None:\n                yield response\n        self.state.record_event(request_id, event=\"finish\")\n\n    def _handle_completion(\n        self, request: openai_api_protocol.CompletionRequest, request_id: str\n    ) -> Iterator[openai_api_protocol.CompletionResponse]:\n        \"\"\"The implementation for synchronous CompletionRequest handling.\n\n        Yields\n        ------\n        stream_response : CompletionResponse\n            The stream response conforming to OpenAI API.\n            See mlc_llm/protocol/openai_api_protocol.py or\n            https://platform.openai.com/docs/api-reference/completions/object for specification.\n\n        Raises\n        ------\n        e : BadRequestError\n            BadRequestError is raised when the request is invalid.\n        \"\"\"\n        (\n            prompt,\n            generation_cfg,\n            prompt_length,\n            echo_response,\n        ) = engine_base.process_completion_request(\n            request,\n            request_id,\n            self.state,\n            self.tokenizer,\n            self.max_input_sequence_length,\n            self.conv_template.model_copy(deep=True),\n        )\n        _ = prompt_length\n        if echo_response is not None:\n            yield echo_response\n\n        finish_reasons: List[Optional[str]] = [None for _ in range(generation_cfg.n)]\n        self.state.record_event(request_id, event=\"invoke generate\")\n        for delta_outputs in self._generate(prompt, generation_cfg, request_id):  # type: ignore\n            response = engine_base.process_completion_stream_output(\n                delta_outputs,\n                request,\n                request_id,\n                self.state,\n                finish_reasons,\n            )\n            if response is not None:\n                yield response\n\n        suffix_response = 
engine_base.create_completion_suffix_response(\n            request, request_id, finish_reasons\n        )\n        if suffix_response is not None:\n            yield suffix_response\n        self.state.record_event(request_id, event=\"finish\")\n\n    def _generate(  # pylint: disable=too-many-locals\n        self,\n        prompt: Union[str, List[int], List[Union[str, List[int], data.Data]]],\n        generation_config: GenerationConfig,\n        request_id: str,\n    ) -> Iterator[List[engine_base.CallbackStreamOutput]]:\n        \"\"\"Internal synchronous text generation interface of MLCEngine.\n        The method is a generator that streams a list of CallbackStreamOutput\n        at a time via yield. The returned list length is the number of\n        parallel generations specified by `generation_config.n`,\n        except for the final chunk (which is always a List of size 1 and carries the usage).\n\n        Parameters\n        ----------\n        prompt : Union[str, List[int], List[Union[str, List[int], data.Data]]]\n            The input prompt in forms of text strings, lists of token ids or data.\n\n        generation_config : GenerationConfig\n            The generation config of the request.\n\n        request_id : str\n            The unique identifier (in string) of this generation request.\n\n        Yields\n        ------\n        request_output : List[engine_base.CallbackStreamOutput]\n            The delta generated outputs in a list.\n            Except for the final chunk, the number of list elements equals `generation_config.n`,\n            and each element corresponds to the delta output of a parallel generation.\n        \"\"\"\n        if self._terminated:\n            raise ValueError(\"The engine has terminated.\")\n\n        # Create the request with the given id, input data, generation\n        # config and the created callback.\n        input_data = engine_utils.convert_prompts_to_data(prompt)\n        request = 
self._ffi[\"create_request\"](\n            request_id, input_data, generation_config.model_dump_json(by_alias=True)\n        )\n\n        # Record the stream in the tracker\n        self.state.sync_output_queue = queue.Queue()\n        self.state.sync_text_streamers = [\n            TextStreamer(self.tokenizer) for _ in range(generation_config.n)\n        ]\n        self._ffi[\"add_request\"](request)\n\n        def abort_request():\n            \"\"\"clean up request if exception happens\"\"\"\n            self.abort(request_id)\n\n        # Block on the output queue and yield the delta outputs as they arrive.\n        with engine_utils.ErrorCleanupScope(abort_request):\n            while True:\n                delta_outputs = self.state.sync_output_queue.get()\n                request_outputs, request_final_usage_json_str = self._request_stream_callback_impl(\n                    delta_outputs\n                )\n                for request_output in request_outputs:  # pylint: disable=use-yield-from\n                    yield request_output\n\n                if request_final_usage_json_str is not None:\n                    # final chunk, we can break\n                    output = engine_base.CallbackStreamOutput(\n                        delta_text=\"\",\n                        delta_logprob_json_strs=None,\n                        finish_reason=None,\n                        request_final_usage_json_str=request_final_usage_json_str,\n                    )\n                    yield [output]\n                    break\n\n    def _request_stream_callback_impl(\n        self, delta_outputs: List[data.RequestStreamOutput]\n    ) -> Tuple[List[List[engine_base.CallbackStreamOutput]], Optional[str]]:\n        \"\"\"The underlying implementation of request stream callback of MLCEngine.\"\"\"\n        batch_outputs: List[List[engine_base.CallbackStreamOutput]] = []\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n         
   self.state.record_event(request_id, event=\"start callback\")\n\n            # final chunk is now always indicated by a chunk\n            # where usage json is present\n            # the backend engine always streams back this chunk\n            # regardless of include_usage option\n            is_final_chunk = stream_outputs[0].request_final_usage_json_str is not None\n            if is_final_chunk:\n                return (batch_outputs, stream_outputs[0].request_final_usage_json_str)\n\n            outputs: List[engine_base.CallbackStreamOutput] = []\n            for stream_output, text_streamer in zip(stream_outputs, self.state.sync_text_streamers):\n                self.state.record_event(request_id, event=\"start detokenization\")\n                delta_text = stream_output.extra_prefix_string + (\n                    text_streamer.put(stream_output.delta_token_ids)\n                    if len(stream_output.delta_token_ids) > 0\n                    else \"\"\n                )\n                if stream_output.finish_reason is not None:\n                    delta_text += text_streamer.finish()\n                self.state.record_event(request_id, event=\"finish detokenization\")\n\n                outputs.append(\n                    engine_base.CallbackStreamOutput(\n                        delta_text=delta_text,\n                        delta_logprob_json_strs=stream_output.delta_logprob_json_strs,\n                        finish_reason=stream_output.finish_reason,\n                        request_final_usage_json_str=None,\n                    )\n                )\n            batch_outputs.append(outputs)\n            self.state.record_event(request_id, event=\"finish callback\")\n        return (batch_outputs, None)\n"
  },
  {
    "path": "python/mlc_llm/serve/engine_base.py",
    "content": "\"\"\"The MLC LLM Serving engine base class.\"\"\"\n\n# pylint: disable=too-many-lines\n\nimport ast\nimport asyncio\nimport json\nimport numbers\nimport queue\nimport sys\nimport threading\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union\n\nimport tvm\nfrom tvm.runtime import Device\n\nfrom mlc_llm.protocol import openai_api_protocol\nfrom mlc_llm.protocol.conversation_protocol import Conversation\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.protocol.mlc_chat_config import MLCChatConfig\nfrom mlc_llm.serve import data, engine_utils\nfrom mlc_llm.serve.config import EngineConfig\nfrom mlc_llm.serve.event_trace_recorder import EventTraceRecorder\nfrom mlc_llm.support import download_cache, logging\nfrom mlc_llm.support.auto_device import detect_device\nfrom mlc_llm.support.style import green\nfrom mlc_llm.tokenizers import TextStreamer, Tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\n@dataclass\nclass ModelInfo:\n    \"\"\"The model info dataclass.\n\n    Parameters\n    ----------\n    model : str\n        The identifier of the input model.\n        It may be a compiled model's id (e.g., \"Llama-2-7b-chat-hf-q4f16_1\"),\n        or a full path to a model directory\n        (e.g., \"dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1\")\n\n    model_lib : Optional[str]\n        The path to the compiled library of the model.\n        E.g., \"dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so\"\n    \"\"\"\n\n    model: str\n    model_lib: Optional[str] = None\n\n\ndef _check_engine_config(\n    model: str,\n    model_lib: Optional[str],\n    mode: Literal[\"local\", \"interactive\", \"server\"],\n    engine_config: EngineConfig,\n) -> None:\n    \"\"\"Check if the given engine config is valid.\"\"\"\n    if engine_config.model is not None and engine_config.model != model:\n        raise ValueError(\n            f'The 
argument \"model\" of engine constructor is \"{model}\", while the \"model\" '\n            f'field in argument \"engine_config\" is \"{engine_config.model}\". '\n            'Please set the \"engine_config.model\" to None or set it to the same as the '\n            'argument \"model\".'\n        )\n    if (\n        engine_config.model_lib is not None\n        and model_lib is not None\n        and engine_config.model_lib != model_lib\n    ):\n        raise ValueError(\n            f'The argument \"model_lib\" of engine constructor is \"{model_lib}\", while the '\n            f'\"model_lib\" field in argument \"engine_config\" is \"{engine_config.model_lib}\". '\n            'Please set the \"engine_config.model_lib\" to None or set it to the same as the '\n            'argument \"model_lib\".'\n        )\n    if engine_config.mode is not None and engine_config.mode != mode:\n        raise ValueError(\n            f'The argument \"mode\" of engine constructor is \"{mode}\", while the '\n            f'\"mode\" field in argument \"engine_config\" is \"{engine_config.mode}\". '\n            'Please set the \"engine_config.mode\" to None or set it to the same as the '\n            'argument \"mode\".'\n        )\n    if engine_config.kv_cache_page_size != 16:\n        raise ValueError(\n            'KV cache only supports page size 16, while the \"kv_cache_page_size\" field in '\n            f'argument \"engine_config\" is \"{engine_config.kv_cache_page_size}\". 
'\n            'Please set \"engine_config.kv_cache_page_size\" to 16.'\n        )\n\n\ndef _parse_models(\n    model: str,\n    model_lib: Optional[str],\n    additional_models: List[Union[str, Tuple[str, str]]],\n) -> List[ModelInfo]:\n    \"\"\"Parse the specified model paths and model libs.\n    Return a list of ModelInfo, which is a wrapper class of the model path + lib path.\n    \"\"\"\n    models = [ModelInfo(model, model_lib)]\n    for additional_model in additional_models:\n        if isinstance(additional_model, str):\n            models.append(ModelInfo(additional_model))\n        else:\n            models.append(ModelInfo(additional_model[0], additional_model[1]))\n    return models\n\n\ndef _process_model_args(\n    models: List[ModelInfo],\n    device: tvm.runtime.Device,\n    engine_config: EngineConfig,\n) -> Tuple[List[Tuple[str, str]], List[str], Conversation]:\n    \"\"\"Process the input ModelInfo to get the engine initialization arguments.\"\"\"\n    conversation: Optional[Conversation] = None\n    config_file_paths: List[str] = []\n\n    def _convert_model_info(model: ModelInfo) -> Tuple[str, str]:\n        nonlocal conversation\n\n        model_path = download_cache.get_or_download_model(model.model)\n        mlc_config_path = model_path / \"mlc-chat-config.json\"\n        config_file_paths.append(str(mlc_config_path))\n\n        with open(mlc_config_path, mode=\"rt\", encoding=\"utf-8\") as file:\n            mlc_chat_config = MLCChatConfig.model_validate_json(file.read())\n\n        if conversation is None:\n            conversation = mlc_chat_config.conv_template\n\n        if model.model_lib is not None:\n            # do model lib search if the model lib is provided\n            # error out if file not found\n            if model.model_lib.startswith(\"mock://\"):\n                model_lib = model.model_lib\n                logger.info(\"[DEBUG] mock test: %s\", model_lib)\n            elif Path(model.model_lib).is_file():\n            
    model_lib = model.model_lib\n                logger.info(\"Using library model: %s\", model_lib)\n            else:\n                raise FileNotFoundError(\n                    f\"The `model_lib` you passed in is not a file: {model.model_lib}.\\n\"\n                )\n        else:\n            # Run JIT compilation if model_lib is not provided.\n            # NOTE: we only import jit when necessary\n            # so that the engine does not have to depend on compilation.\n            from mlc_llm.interface import jit  # pylint: disable=import-outside-toplevel\n\n            model_compile_overrides = {\n                \"context_window_size\": engine_config.max_single_sequence_length,\n                \"prefill_chunk_size\": engine_config.prefill_chunk_size,\n                \"sliding_window_size\": engine_config.sliding_window_size,\n                \"attention_sink_size\": engine_config.attention_sink_size,\n                \"tensor_parallel_shards\": engine_config.tensor_parallel_shards,\n                \"pipeline_parallel_stages\": engine_config.pipeline_parallel_stages,\n                \"max_batch_size\": engine_config.max_num_sequence,\n                \"opt\": engine_config.opt,\n            }\n\n            model_lib = jit.jit(\n                model_path=model_path,\n                overrides=model_compile_overrides,\n                device=device,\n            ).model_lib_path\n        return str(model_path), model_lib\n\n    model_args: List[Tuple[str, str]] = [_convert_model_info(model) for model in models]\n\n    assert conversation is not None\n    return model_args, config_file_paths, conversation\n\n\ndef _print_engine_mode_logging_msg(\n    mode: Literal[\"local\", \"interactive\", \"server\"],\n) -> None:\n    \"\"\"Print the logging info for engine mode selection.\"\"\"\n    if mode == \"local\":\n        logger.info(\n            \"The selected engine mode is %s. 
\"\n            \"We choose small max batch size and KV cache capacity to use less GPU memory.\",\n            green(mode),\n        )\n    elif mode == \"interactive\":\n        logger.info(\n            \"The selected engine mode is %s. \"\n            \"We fix max batch size to 1 for interactive single sequence use.\",\n            green(mode),\n        )\n    else:\n        logger.info(\n            \"The selected engine mode is %s. \"\n            \"We use as much GPU memory as possible (within the limit \"\n            \"of gpu_memory_utilization).\",\n            green(mode),\n        )\n\n    if mode != \"local\":\n        logger.info(\n            \"If you have low concurrent requests and want to use less GPU memory, \"\n            'please select mode \"local\".'\n        )\n    if mode != \"interactive\":\n        logger.info(\n            \"If you don't have concurrent requests and only use the engine interactively, \"\n            'please select mode \"interactive\".'\n        )\n    if mode != \"server\":\n        logger.info(\n            \"If you have high concurrent requests and want to maximize the GPU memory utilization, \"\n            'please select mode \"server\".'\n        )\n\n\nclass EngineMetrics:\n    \"\"\"Class to store the result returned by engine metrics\"\"\"\n\n    metrics: dict\n\n    def __init__(self, metrics):\n        self.metrics = metrics\n\n    def __str__(self):\n        return self.metrics.__str__()\n\n    def __repr__(self):\n        return self.metrics.__repr__()\n\n    def __getitem__(self, key):\n        return self.metrics[key]\n\n    def prometheus_text(self) -> str:\n        \"\"\"Convert engine metrics into prometheus text format\n\n        Returns\n        -------\n        text: str\n            The metrics in prometheus text format\n        \"\"\"\n        output_lines = [\n            \"# NOTE: these metrics count token in the unit of serving model's tokenization\",\n            \"# be careful when comparing 
them to client-side metrics that may use\",\n            \"# different tokenization to standardize across models.\\n\",\n        ]\n\n        def traverse(comment_scope, key_prefix, curr_value):\n            if isinstance(curr_value, dict):\n                if comment_scope:\n                    output_lines.append(f\"\\n# {comment_scope}\")\n                # first prioritize metrics in current scope\n                for key, value in curr_value.items():\n                    if isinstance(value, numbers.Number):\n                        output_lines.append(f\"{key_prefix}{key}\\t{value}\")\n                # then look into nested scopes if any\n                for key, value in curr_value.items():\n                    if isinstance(value, dict) and len(value) != 0:\n                        traverse(f\"{comment_scope}/{key}\", f\"{key_prefix}{key}_\", value)\n\n        traverse(\"\", \"\", self.metrics)\n        return \"\\n\".join(output_lines)\n\n\ndef _query_engine_metrics(engine):\n    \"\"\"Query engine metrics via debug options\"\"\"\n    dummy_message = {\"role\": \"user\", \"context\": \"\"}\n    for response in engine.chat.completions.create(\n        messages=[dummy_message],\n        model=\"model\",\n        stream=True,\n        stream_options={\"include_usage\": True},\n        extra_body={\"debug_config\": {\"special_request\": \"query_engine_metrics\"}},\n    ):\n        if response.usage is not None:\n            return EngineMetrics(response.usage.extra)\n    raise RuntimeError(\"query_engine metrics did not get metrics back\")\n\n\nasync def _async_query_engine_metrics(engine):\n    \"\"\"Query engine metrics via debug options\"\"\"\n    dummy_message = {\"role\": \"user\", \"context\": \"\"}\n    result = None\n    async for response in await engine.chat.completions.create(\n        messages=[dummy_message],\n        model=\"model\",\n        stream=True,\n        stream_options={\"include_usage\": True},\n        extra_body={\"debug_config\": 
{\"special_request\": \"query_engine_metrics\"}},\n    ):\n        if response.usage is not None:\n            assert result is None\n            result = EngineMetrics(response.usage.extra)\n\n    if result is not None:\n        return result\n    raise RuntimeError(\"query_engine metrics did not get metrics back\")\n\n\n@dataclass\nclass CallbackStreamOutput:\n    \"\"\"The output of MLCEngine._generate and AsyncMLCEngine._generate\n\n    Attributes\n    ----------\n    delta_text : str\n        The delta text generated since the last output.\n\n    delta_logprob_json_strs : Optional[List[str]]\n        The list of logprob JSON strings since the last output,\n        or None if the request does not require logprobs.\n\n    finish_reason : Optional[str]\n        The finish reason of the request, or None if unfinished.\n\n    request_final_usage_json_str : Optional[str]\n        The usage JSON string, which only appears in the last chunk.\n        When it is present, all the other fields are empty.\n    \"\"\"\n\n    delta_text: str\n    delta_logprob_json_strs: Optional[List[str]]\n    finish_reason: Optional[str]\n    request_final_usage_json_str: Optional[str]\n\n\nclass AsyncRequestStream:\n    \"\"\"The asynchronous stream for requests in AsyncMLCEngine.\n\n    Each request has its own unique stream.\n    The stream exposes the method `push` for the engine to push newly generated\n    delta outputs to the stream, and the method `finish` for the engine to mark\n    the end of generation.\n\n    The stream implements `__aiter__` and `__anext__`, which the engine\n    can use to iterate over all the generated outputs in order asynchronously.\n    \"\"\"\n\n    # The asynchronous queue to hold elements of either a list of\n    # CallbackStreamOutput or an exception.\n    if sys.version_info >= (3, 9):\n        _queue: asyncio.Queue[  # pylint: disable=unsubscriptable-object\n            Union[List[CallbackStreamOutput], Exception]\n        ]\n    else:\n        _queue: asyncio.Queue\n    # The finish 
flag.\n    _finished: bool\n\n    def __init__(self) -> None:\n        self._queue = asyncio.Queue()\n        self._finished = False\n\n    def push(self, item_or_exception: Union[List[CallbackStreamOutput], Exception]) -> None:\n        \"\"\"Push a new list of outputs (or an exception) to the stream.\"\"\"\n        if self._finished:\n            # No new item is expected after finish.\n            self._queue.put_nowait(\n                RuntimeError(\n                    \"The request has already finished. \"\n                    \"The stream is not supposed to accept new items.\"\n                )\n            )\n            return\n        self._queue.put_nowait(item_or_exception)\n\n    def finish(self) -> None:\n        \"\"\"Mark the end of generation in the stream.\"\"\"\n        self._queue.put_nowait(StopIteration())\n        self._finished = True\n\n    def __aiter__(self):\n        return self\n\n    async def __anext__(self) -> List[CallbackStreamOutput]:\n        result = await self._queue.get()\n        if isinstance(result, StopIteration):\n            raise StopAsyncIteration\n        if isinstance(result, Exception):\n            raise result\n        return result\n\n\nclass EngineState:\n    \"\"\"The engine states that the request stream callback function may use.\n\n    This class is used for both AsyncMLCEngine and MLCEngine.\n    AsyncMLCEngine uses the fields and methods starting with \"async\",\n    and MLCEngine uses the ones starting with \"sync\".\n\n    - For AsyncMLCEngine, the state contains an asynchronous event loop,\n    the streamers and the number of unfinished generations for each request\n    being processed.\n    - For MLCEngine, the state contains a callback output blocking queue,\n    the text streamers and the number of unfinished requests.\n\n    We use this state class to prevent the callback function from capturing\n    the AsyncMLCEngine.\n\n    The state also optionally maintains an event trace recorder, which can\n    provide Chrome 
tracing when enabled.\n    \"\"\"\n\n    trace_recorder = None\n    # States used for AsyncMLCEngine\n    async_event_loop: Optional[asyncio.AbstractEventLoop] = None\n    async_streamers: Dict[str, Tuple[AsyncRequestStream, List[TextStreamer]]] = {}\n    # States used for MLCEngine\n    sync_output_queue: queue.Queue = queue.Queue()\n    sync_text_streamers: List[TextStreamer] = []\n\n    def __init__(self, enable_tracing: bool) -> None:\n        \"\"\"Constructor.\"\"\"\n        # Create per-instance containers so that multiple engines in one\n        # process do not share the mutable class-level defaults above.\n        self.async_streamers = {}\n        self.sync_output_queue = queue.Queue()\n        self.sync_text_streamers = []\n        if enable_tracing:\n            self.trace_recorder = EventTraceRecorder()\n\n    def record_event(self, request_id: str, event: str) -> None:\n        \"\"\"Record an event for the input request in the trace\n        recorder when the recorder exists.\n\n        Parameters\n        ----------\n        request_id : str\n            The subject request of the event.\n\n        event : str\n            The name of the event.\n            It can have one of the following patterns:\n            - \"start xxx\", which marks the start of event \"xxx\",\n            - \"finish xxx\", which marks the finish of event \"xxx\",\n            - \"yyy\", which marks the instant event \"yyy\".\n            The \"starts\" and \"finishes\" will be automatically paired in the trace recorder.\n        \"\"\"\n        if self.trace_recorder is None:\n            return\n        self.trace_recorder.add_event(request_id, event)\n\n    def get_request_stream_callback(\n        self, kind: Literal[\"async\", \"sync\"]\n    ) -> Callable[[List[data.RequestStreamOutput]], None]:\n        \"\"\"Construct and return a callback function.\n\n        The callback function has signature\n        \"Callable[[List[data.RequestStreamOutput]], None]\",\n        whose input is a list of \"data.RequestStreamOutput\".\n        Each \"data.RequestStreamOutput\" is the delta output of a request,\n        generated from the engine.\n        \"\"\"\n\n        f_callback = (\n            self._async_request_stream_callback\n    
        if kind == \"async\"\n            else self._sync_request_stream_callback\n        )\n\n        def _callback(delta_outputs: List[data.RequestStreamOutput]) -> None:\n            f_callback(delta_outputs)\n\n        return _callback\n\n    def async_lazy_init_event_loop(self) -> None:\n        \"\"\"Lazily set the asyncio event loop so that the event\n        loop is the main driving event loop of the process.\n        \"\"\"\n        if self.async_event_loop is None:\n            self.async_event_loop = asyncio.get_event_loop()\n\n    def _async_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None:\n        \"\"\"The request stream callback function for AsyncMLCEngine to stream back\n        the request generation results.\n\n        Note\n        ----\n        This callback function uses `call_soon_threadsafe` in asyncio to\n        schedule the invocation in the event loop, so that the underlying\n        callback logic will be executed asynchronously in the future rather\n        than right now.\n        \"\"\"\n\n        # Schedule a callback run in the event loop without executing right now.\n        # NOTE: This function causes GIL during execution.\n        self.async_event_loop.call_soon_threadsafe(\n            self._async_request_stream_callback_impl, delta_outputs\n        )\n\n    def _async_request_stream_callback_impl(\n        self, delta_outputs: List[data.RequestStreamOutput]\n    ) -> None:\n        \"\"\"The underlying implementation of request stream callback for AsyncMLCEngine.\"\"\"\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            streamers = self.async_streamers.get(request_id, None)\n            if streamers is None:\n                continue\n\n            self.record_event(request_id, event=\"start callback\")\n            stream, text_streamers = streamers\n\n            # final chunk is now always indicated by a chunk\n        
    # where usage json is present\n            # the backend engine always streams back this chunk\n            # regardless of include_usage option\n            is_final_chunk = stream_outputs[0].request_final_usage_json_str is not None\n            if is_final_chunk:\n                # stream back this final usage chunk\n                output = CallbackStreamOutput(\n                    delta_text=\"\",\n                    delta_logprob_json_strs=None,\n                    finish_reason=None,\n                    request_final_usage_json_str=stream_outputs[0].request_final_usage_json_str,\n                )\n                stream.push([output])\n                stream.finish()\n                self.async_streamers.pop(request_id, None)\n                continue\n\n            outputs = []\n            for stream_output, text_streamer in zip(stream_outputs, text_streamers):\n                self.record_event(request_id, event=\"start detokenization\")\n                delta_text = stream_output.extra_prefix_string + (\n                    text_streamer.put(stream_output.delta_token_ids)\n                    if len(stream_output.delta_token_ids) > 0\n                    else \"\"\n                )\n                if stream_output.finish_reason is not None:\n                    delta_text += text_streamer.finish()\n                self.record_event(request_id, event=\"finish detokenization\")\n\n                outputs.append(\n                    CallbackStreamOutput(\n                        delta_text=delta_text,\n                        delta_logprob_json_strs=stream_output.delta_logprob_json_strs,\n                        finish_reason=stream_output.finish_reason,\n                        request_final_usage_json_str=None,\n                    )\n                )\n\n            # Push new delta text to the stream.\n            stream.push(outputs)\n            self.record_event(request_id, event=\"finish callback\")\n\n    def 
_sync_request_stream_callback(self, delta_outputs: List[data.RequestStreamOutput]) -> None:\n        \"\"\"The request stream callback function for MLCEngine to stream back\n        the request generation results.\n        \"\"\"\n        # Put the delta outputs into the queue without blocking.\n        self.sync_output_queue.put_nowait(delta_outputs)\n\n\nclass MLCEngineBase:  # pylint: disable=too-many-instance-attributes,too-few-public-methods\n    \"\"\"The base engine class, which implements common functions that\n    are shared by MLCEngine and AsyncMLCEngine.\n\n    This class wraps a threaded engine that runs on a standalone\n    thread and streams back the generated delta results via\n    callback functions. The internal threaded engine keeps running a\n    loop that drives the engine.\n\n    MLCEngine and AsyncMLCEngine inherit this MLCEngineBase class, and implement\n    their own methods to process the generated delta results received\n    from callback functions and yield the processed delta results in\n    the form of standard API protocols.\n\n    Check out the subclasses AsyncMLCEngine/MLCEngine for the documentation of the constructor parameters.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        kind: Literal[\"async\", \"sync\"],\n        model: str,\n        device: Union[str, tvm.runtime.Device],\n        model_lib: Optional[str],\n        mode: Literal[\"local\", \"interactive\", \"server\"],\n        engine_config: Optional[EngineConfig],\n        enable_tracing: bool,\n    ) -> None:\n        # - Check the fields of `engine_config`.\n        if engine_config is None:\n            engine_config = EngineConfig()\n        _check_engine_config(model, model_lib, mode, engine_config)\n\n        # - Initialize model loading info.\n        models = _parse_models(model, model_lib, engine_config.additional_models)\n        if isinstance(device, str):\n            device = 
detect_device(device)\n        assert isinstance(device, Device)\n        (\n            model_args,\n            model_config_paths,\n            self.conv_template,\n        ) = _process_model_args(models, device, engine_config)\n\n        # - Load the raw model configs into dicts.\n        self.model_config_dicts = []\n        for i, model_info in enumerate(models):\n            model_info.model_lib = model_args[i][1]\n            with open(model_config_paths[i], \"r\", encoding=\"utf-8\") as file:\n                self.model_config_dicts.append(json.load(file))\n\n        # - Print logging info regarding the mode selection.\n        if engine_config.verbose:\n            _print_engine_mode_logging_msg(mode)\n\n        # - Initialize engine state and engine.\n        self.state = EngineState(enable_tracing)\n        module = tvm.get_global_func(\"mlc.serve.create_threaded_engine\", allow_missing=False)()\n        self._ffi = {\n            key: module[key]\n            for key in [\n                \"add_request\",\n                \"abort_request\",\n                \"run_background_loop\",\n                \"run_background_stream_back_loop\",\n                \"reload\",\n                \"init_threaded_engine\",\n                \"exit_background_loop\",\n                \"create_request\",\n                \"get_complete_engine_config\",\n                \"reset\",\n                \"debug_call_func_on_all_worker\",\n            ]\n        }\n        self.tokenizer = Tokenizer(model_args[0][0])\n        self._ffi[\"init_threaded_engine\"](\n            device,\n            self.state.get_request_stream_callback(kind),\n            self.state.trace_recorder,\n        )\n\n        background_loop = self._ffi[\"run_background_loop\"]\n        background_stream_back_loop = self._ffi[\"run_background_stream_back_loop\"]\n\n        # - Create the background engine-driving threads and start the loops.\n        self._background_loop_thread: threading.Thread = 
threading.Thread(target=background_loop)\n        self._background_stream_back_loop_thread: threading.Thread = threading.Thread(\n            target=background_stream_back_loop\n        )\n        self._background_loop_thread.start()\n        self._background_stream_back_loop_thread.start()\n        self._terminated = False\n\n        engine_config.model = model_args[0][0]\n        engine_config.model_lib = model_args[0][1]\n        engine_config.additional_models = model_args[1:]  # type: ignore\n        engine_config.mode = mode\n        self._ffi[\"reload\"](engine_config.asjson())\n        self.engine_config = EngineConfig.from_json(self._ffi[\"get_complete_engine_config\"]())\n        self.max_input_sequence_length = min(\n            self.engine_config.max_single_sequence_length,\n            self.engine_config.max_total_sequence_length,\n        )\n\n    def __del__(self):\n        \"\"\"Destructor that automatically terminates the engine.\"\"\"\n        self.terminate()\n\n    def terminate(self):\n        \"\"\"Terminate the engine.\"\"\"\n        if hasattr(self, \"_terminated\") and self._terminated:\n            return\n        self._terminated = True\n        if not hasattr(self, \"_ffi\"):\n            return\n        self._ffi[\"exit_background_loop\"]()\n        if hasattr(self, \"_background_loop_thread\"):\n            self._background_loop_thread.join()\n        if hasattr(self, \"_background_stream_back_loop_thread\"):\n            self._background_stream_back_loop_thread.join()\n\n    def _debug_call_func_on_all_worker(\n        self, func_name: str, func_args: Optional[str] = None\n    ) -> None:\n        \"\"\"Call the given global function on all workers. 
Only for debugging purposes.\"\"\"\n        self._ffi[\"debug_call_func_on_all_worker\"](func_name, func_args)\n\n    def reset(self):\n        \"\"\"Reset the engine, clearing the running data and metrics.\"\"\"\n        return self._ffi[\"reset\"]()\n\n\ndef process_chat_completion_request(  # pylint: disable=too-many-arguments\n    request: openai_api_protocol.ChatCompletionRequest,\n    request_id: str,\n    engine_state: EngineState,\n    model_config: Dict[str, Any],\n    f_tokenize: Callable[[str], List[int]],\n    max_input_sequence_length: int,\n    conv_template: Conversation,\n) -> Tuple[List[Union[List[int], data.Data]], GenerationConfig, bool, int]:\n    \"\"\"Process the given ChatCompletionRequest, apply request validity\n    checks, and return the processed prompts and other info.\n\n    Parameters\n    ----------\n    request : openai_api_protocol.ChatCompletionRequest\n        The request to be processed and checked.\n\n    request_id : str\n        The id of the request.\n\n    engine_state : EngineState\n        The state of the engine.\n\n    model_config : Dict[str, Any]\n        The model configuration dictionary.\n\n    f_tokenize : Callable[[str], List[int]]\n        The tokenizer encode function.\n\n    max_input_sequence_length : int\n        The maximum allowed total prompt length.\n\n    conv_template : Conversation\n        The conversation template of the model.\n\n    Returns\n    -------\n    prompts : List[Union[List[int], data.Data]]\n        The prompts, in a list.\n        Each element is a list of token ids or a \"data.Data\" instance.\n\n    generation_cfg : GenerationConfig\n        The generation config extracted from the input request.\n\n    use_function_calling : bool\n        A boolean flag indicating whether the request uses function calling.\n\n    prompt_length : int\n        The total prompt length.\n    \"\"\"\n    engine_state.record_event(request_id, event=\"receive request\")\n    # - Check if unsupported arguments are 
specified.\n    engine_utils.check_unsupported_fields(request)\n\n    # - Process messages and update the conversation template in three steps:\n    #   i. Check the message validity.\n    #  ii. Add the input messages to the conversation template.\n    # iii. Add the additional message for the assistant.\n    request.check_message_validity()\n    # - Check for function calling usage and update the conversation template\n    request.check_function_call_usage(conv_template)\n\n    for message in request.messages:\n        role = message.role\n        content = message.content\n        if role == \"system\":\n            assert isinstance(content, str)\n            conv_template.system_message = content if content is not None else \"\"\n            continue\n        conv_template.messages.append((role, content))\n    conv_template.messages.append((\"assistant\", None))\n\n    # - Get the prompt from template, and encode to token ids.\n    # - Check prompt length\n    engine_state.record_event(request_id, event=\"start tokenization\")\n    prompts = engine_utils.process_prompts(  # type: ignore\n        conv_template.as_prompt(model_config), f_tokenize\n    )\n    engine_state.record_event(request_id, event=\"finish tokenization\")\n\n    if conv_template.system_prefix_token_ids is not None:\n        if isinstance(prompts[0], list):\n            prompts[0] = conv_template.system_prefix_token_ids + prompts[0]\n        else:\n            prompts.insert(0, conv_template.system_prefix_token_ids)\n    prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length)\n\n    # Process generation config. 
Create request id.\n    generation_cfg = engine_utils.get_generation_config(\n        request,\n        extra_stop_token_ids=conv_template.stop_token_ids,\n        extra_stop_str=conv_template.stop_str,\n    )\n    return prompts, generation_cfg, conv_template.use_function_calling, prompt_length\n\n\ndef process_chat_completion_stream_output(  # pylint: disable=too-many-arguments\n    delta_outputs: List[CallbackStreamOutput],\n    request: openai_api_protocol.ChatCompletionRequest,\n    request_id: str,\n    engine_state: EngineState,\n    use_function_calling: bool,\n    finish_reasons: List[Optional[str]],\n) -> Optional[openai_api_protocol.ChatCompletionStreamResponse]:\n    \"\"\"Process the delta outputs of a single request of ChatCompletion,\n    convert them to a ChatCompletionStreamResponse and return it.\n\n    Parameters\n    ----------\n    delta_outputs : List[CallbackStreamOutput]\n        The delta outputs of a request.\n        The list length is the number of parallel generations specified by \"n\".\n        Each element corresponds to a generation.\n\n    request : openai_api_protocol.ChatCompletionRequest\n        The request whose outputs are being processed.\n\n    request_id : str\n        The id of the request.\n\n    engine_state : EngineState\n        The state of the engine.\n\n    use_function_calling : bool\n        A boolean flag indicating whether the request uses function calling.\n\n    finish_reasons : List[Optional[str]]\n        The list of finish reasons of each generation.\n        The list length is the number of parallel generations specified by \"n\".\n        This list is updated in place.\n\n    Returns\n    -------\n    response : Optional[openai_api_protocol.ChatCompletionStreamResponse]\n        The converted OpenAI API ChatCompletionStreamResponse instance.\n        It is None when there is nothing to stream back.\n    \"\"\"\n    # we always stream back the final chunk with usage\n    is_final_chunk = delta_outputs[0].request_final_usage_json_str is not None\n    if is_final_chunk:\n        assert len(delta_outputs) == 1\n        
engine_state.record_event(request_id, event=\"yield final usage\")\n        response = openai_api_protocol.ChatCompletionStreamResponse(\n            id=request_id,\n            choices=[],\n            model=request.model,\n            system_fingerprint=\"\",\n            usage=openai_api_protocol.CompletionUsage.model_validate_json(\n                delta_outputs[0].request_final_usage_json_str\n            ),\n        )\n        # non streaming mode always comes with usage\n        if not request.stream:\n            return response\n        # skip usage if stream option does not indicate include usage\n        if request.stream_options is None:\n            return None\n        if not request.stream_options.include_usage:\n            return None\n        return response\n\n    # normal chunk\n    assert len(delta_outputs) == request.n\n    choices = []\n    for i, delta_output in enumerate(delta_outputs):\n        finish_reason_updated = False\n        if delta_output.finish_reason is not None and finish_reasons[i] is None:\n            finish_reasons[i] = (\n                delta_output.finish_reason if not use_function_calling else \"tool_calls\"\n            )\n            finish_reason_updated = True\n        if not finish_reason_updated and delta_output.delta_text == \"\":\n            # Ignore empty delta text when finish reason is not updated.\n            engine_state.record_event(request_id, event=\"skip empty delta text\")\n            continue\n\n        choices.append(\n            openai_api_protocol.ChatCompletionStreamResponseChoice(\n                index=i,\n                finish_reason=finish_reasons[i],\n                delta=openai_api_protocol.ChatCompletionMessage(\n                    content=delta_output.delta_text, role=\"assistant\"\n                ),\n                logprobs=(\n                    openai_api_protocol.LogProbs(\n                        content=[\n                            
openai_api_protocol.LogProbsContent.model_validate_json(\n                                logprob_json_str\n                            )\n                            for logprob_json_str in delta_output.delta_logprob_json_strs\n                        ]\n                    )\n                    if delta_output.delta_logprob_json_strs is not None\n                    else None\n                ),\n            )\n        )\n\n    if len(choices) == 0:\n        # Skip returning a chunk when no choice has new text or an updated finish reason.\n        return None\n    response = openai_api_protocol.ChatCompletionStreamResponse(\n        id=request_id, choices=choices, model=request.model, system_fingerprint=\"\"\n    )\n    engine_state.record_event(request_id, event=\"yield delta output\")\n    return response\n\n\ndef process_completion_request(  # pylint: disable=too-many-arguments\n    request: openai_api_protocol.CompletionRequest,\n    request_id: str,\n    engine_state: EngineState,\n    tokenizer: Tokenizer,\n    max_input_sequence_length: int,\n    conv_template: Conversation,\n) -> Tuple[List[int], GenerationConfig, int, Optional[openai_api_protocol.CompletionResponse]]:\n    \"\"\"Process the given CompletionRequest, apply request validity\n    checks, and return the processed prompt along with other info.\n\n    Parameters\n    ----------\n    request : openai_api_protocol.CompletionRequest\n        The request to be processed and checked.\n\n    request_id : str\n        The id of the request.\n\n    engine_state : EngineState\n        The state of the engine.\n\n    tokenizer : Tokenizer\n        The tokenizer instance of the model.\n\n    max_input_sequence_length : int\n        The maximum allowed total prompt length.\n\n    conv_template : Conversation\n        The conversation template of the model.\n\n    Returns\n    -------\n    prompt : List[int]\n        The prompt in a list of token ids.\n\n    generation_cfg : GenerationConfig\n        The 
generation config derived from the input request.\n\n    prompt_length : int\n        The total prompt length.\n\n    echo_response : Optional[openai_api_protocol.CompletionResponse]\n        The CompletionResponse of the echoing part, when argument \"echo\"\n        of the input request is specified.\n    \"\"\"\n    engine_state.record_event(request_id, event=\"receive request\")\n    # - Check if unsupported arguments are specified.\n    engine_utils.check_unsupported_fields(request)\n\n    # - Process prompt and check validity.\n    engine_state.record_event(request_id, event=\"start tokenization\")\n    prompts = engine_utils.process_prompts(request.prompt, tokenizer.encode)\n    engine_state.record_event(request_id, event=\"finish tokenization\")\n    prompt_length = engine_utils.check_and_get_prompts_length(prompts, max_input_sequence_length)\n    prompt = prompts[0]\n    assert isinstance(prompt, list)\n\n    # - Process generation config.\n    generation_cfg = engine_utils.get_generation_config(\n        request,\n        extra_stop_token_ids=conv_template.stop_token_ids,\n        extra_stop_str=conv_template.stop_str,\n    )\n\n    # - Echo back the prompt.\n    echo_response = None\n    if request.echo:\n        text = tokenizer.decode(prompt)\n        response = openai_api_protocol.CompletionResponse(\n            id=request_id,\n            choices=[\n                openai_api_protocol.CompletionResponseChoice(index=i, text=text)\n                for i in range(generation_cfg.n)\n            ],\n            model=request.model,\n            usage=None,\n        )\n        echo_response = response\n    return prompt, generation_cfg, prompt_length, echo_response\n\n\ndef get_logprobs_from_delta(\n    delta_logprob_json_strs: List[str],\n) -> openai_api_protocol.CompletionLogProbs:\n    \"\"\"Convert json strings containing logprobs information to\n    completion response format (OpenAI API compatible)\n\n    Parameters\n    
----------\n    delta_logprob_json_strs : List[str]\n        Logprobs information packed in json strings and\n        kept in the delta outputs of a request.\n\n    Returns\n    -------\n    logprobs : openai_api_protocol.CompletionLogProbs\n        Logprobs information extracted from the JSON strings and\n        converted to completion response format.\n    \"\"\"\n    token_logprobs = []\n    tokens = []\n    top_logprobs = []\n    for logprob_json_str in delta_logprob_json_strs:\n        content = openai_api_protocol.LogProbsContent.model_validate_json(logprob_json_str)\n        tokens.append(content.token)\n        token_logprobs.append(content.logprob)\n        top_logprob_dict = {}\n        for top_logprob in content.top_logprobs:\n            top_logprob_dict[top_logprob.token] = top_logprob.logprob\n        top_logprobs.append(top_logprob_dict)\n    return openai_api_protocol.CompletionLogProbs(\n        # TODO(vvchernov): support text_offset\n        text_offset=None,\n        token_logprobs=token_logprobs,\n        tokens=tokens,\n        top_logprobs=top_logprobs,\n    )\n\n\ndef process_completion_stream_output(  # pylint: disable=too-many-arguments\n    delta_outputs: List[CallbackStreamOutput],\n    request: openai_api_protocol.CompletionRequest,\n    request_id: str,\n    engine_state: EngineState,\n    finish_reasons: List[Optional[str]],\n) -> Optional[openai_api_protocol.CompletionResponse]:\n    \"\"\"Process the delta outputs of a single request of Completion,\n    convert the delta output to CompletionResponse and return.\n\n    Parameters\n    ----------\n    delta_outputs : List[CallbackStreamOutput]\n        The delta outputs of a request.\n        The list length is the number of parallel generations specified by \"n\".\n        Each element corresponds to a generation.\n\n    request : openai_api_protocol.CompletionRequest\n        The request being processed.\n\n    request_id : str\n        The id of the request.\n\n    engine_state : EngineState\n     
   The state of the engine.\n\n    finish_reasons : List[Optional[str]]\n        The list of finish reasons of each generation.\n        The list length is the number of parallel generations specified by \"n\".\n        This list is updated in place.\n\n    Returns\n    -------\n    response : Optional[openai_api_protocol.CompletionResponse]\n        The converted OpenAI API CompletionResponse instance.\n        It can be None when there is no content.\n    \"\"\"\n    # we always stream back the final chunk with usage\n    is_final_chunk = delta_outputs[0].request_final_usage_json_str is not None\n    if is_final_chunk:\n        assert len(delta_outputs) == 1\n        engine_state.record_event(request_id, event=\"yield final usage\")\n        response = openai_api_protocol.CompletionResponse(\n            id=request_id,\n            choices=[],\n            model=request.model,\n            system_fingerprint=\"\",\n            usage=openai_api_protocol.CompletionUsage.model_validate_json(\n                delta_outputs[0].request_final_usage_json_str\n            ),\n        )\n        # non streaming mode always comes with usage\n        if not request.stream:\n            return response\n        if request.stream_options is None:\n            return None\n        if not request.stream_options.include_usage:\n            return None\n        return response\n\n    # normal chunk\n    assert len(delta_outputs) == request.n\n    choices = []\n    for i, delta_output in enumerate(delta_outputs):\n        finish_reason_updated = False\n        if delta_output.finish_reason is not None and finish_reasons[i] is None:\n            finish_reasons[i] = delta_output.finish_reason\n            finish_reason_updated = True\n        if not finish_reason_updated and delta_output.delta_text == \"\":\n            # Ignore empty delta text when finish reason is not updated.\n            continue\n\n        if delta_output.delta_logprob_json_strs is not None:\n            
logprobs = get_logprobs_from_delta(delta_output.delta_logprob_json_strs)\n        else:\n            logprobs = None\n        choices.append(\n            openai_api_protocol.CompletionResponseChoice(\n                index=i,\n                finish_reason=finish_reasons[i],\n                text=delta_output.delta_text,\n                logprobs=logprobs,\n            )\n        )\n\n    if len(choices) == 0:\n        # Skip returning a chunk when no choice has new text or an updated finish reason.\n        return None\n    response = openai_api_protocol.CompletionResponse(\n        id=request_id,\n        choices=choices,\n        model=request.model,\n        usage=None,\n    )\n    engine_state.record_event(request_id, event=\"yield delta output\")\n    return response\n\n\ndef create_completion_suffix_response(\n    request: openai_api_protocol.CompletionRequest,\n    request_id: str,\n    finish_reasons: List[Optional[str]],\n) -> Optional[openai_api_protocol.CompletionResponse]:\n    \"\"\"Create the suffix response of a Completion request\n    when the request specifies a suffix.\n\n    Parameters\n    ----------\n    request : openai_api_protocol.CompletionRequest\n        The request whose suffix response is to be created.\n\n    request_id : str\n        The id of the request.\n\n    finish_reasons : List[Optional[str]]\n        The list of finish reasons of each generation.\n        The list length is the number of parallel generations specified by \"n\".\n\n    Returns\n    -------\n    suffix_response : Optional[openai_api_protocol.CompletionResponse]\n        The created OpenAI API CompletionResponse instance for the suffix.\n        Or None if the request does not specify a suffix.\n    \"\"\"\n    # - Echo the suffix.\n    if request.suffix is None:\n        return None\n    assert all(finish_reason is not None for finish_reason in finish_reasons)\n    response = openai_api_protocol.CompletionResponse(\n        
id=request_id,\n        choices=[\n            openai_api_protocol.CompletionResponseChoice(\n                index=i,\n                finish_reason=finish_reason,\n                text=request.suffix,\n            )\n            for i, finish_reason in enumerate(finish_reasons)\n        ],\n        model=request.model,\n        usage=None,\n    )\n    return response\n\n\ndef convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]:\n    \"\"\"Convert a function call string, or a stringified list of function call strings,\n    to a list of JSON objects. Return None for each invalid function call string.\"\"\"\n\n    def parse_function_call(call_str: str):\n        node = ast.parse(call_str, mode=\"eval\")\n        call_node = node.body\n        if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name):\n            name = call_node.func.id\n            arguments = {}\n            for keyword in call_node.keywords:\n                arguments[keyword.arg] = ast.literal_eval(keyword.value)\n            return {\"name\": name, \"arguments\": arguments}\n        return None\n\n    # Treat the input as a stringified list of call strings if it is wrapped in brackets.\n    if stringified_calls.startswith(\"[\") and stringified_calls.endswith(\"]\"):\n        calls = ast.literal_eval(stringified_calls)\n    else:\n        calls = [stringified_calls]\n    function_calls_json = [parse_function_call(call_str) for call_str in calls]\n    return function_calls_json\n\n\ndef process_function_call_output(\n    output_texts: List[str], finish_reasons: List[str]\n) -> Tuple[bool, List[List[openai_api_protocol.ChatToolCall]]]:\n    \"\"\"Process the potential function call results output by the model,\n    according to the finish reasons.\n    Return whether function calling is used, and the list of tool calls.\n    \"\"\"\n    n = len(output_texts)\n    tool_calls_list: List[List[openai_api_protocol.ChatToolCall]] = [[] for _ in range(n)]\n    use_function_calling = any(finish_reason == \"tool_calls\" for 
finish_reason in finish_reasons)\n    if use_function_calling:\n        for i, output_text in enumerate(output_texts):\n            try:\n                fn_json_list = convert_function_str_to_json(output_text)\n            except (SyntaxError, ValueError):\n                output_texts[i] = \"Got an invalid function call output from model\"\n                finish_reasons[i] = \"error\"\n            else:\n                tool_calls_list[i] = [\n                    openai_api_protocol.ChatToolCall(\n                        type=\"function\",\n                        function=openai_api_protocol.ChatFunctionCall(\n                            name=fn_json_obj[\"name\"], arguments=fn_json_obj[\"arguments\"]\n                        ),\n                    )\n                    for fn_json_obj in fn_json_list\n                    if fn_json_obj is not None\n                ]\n                if len(tool_calls_list[i]) == 0:\n                    output_texts[i] = \"Got an invalid function call output from model\"\n                    finish_reasons[i] = \"error\"\n                else:\n                    finish_reasons[i] = \"tool_calls\"\n    return use_function_calling, tool_calls_list\n\n\ndef wrap_chat_completion_response(  # pylint: disable=too-many-arguments\n    request_id: str,\n    model: str,\n    output_texts: List[str],\n    finish_reasons: List[str],\n    tool_calls_list: List[List[openai_api_protocol.ChatToolCall]],\n    logprob_results: Optional[List[List[openai_api_protocol.LogProbsContent]]],\n    use_function_calling: bool,\n    usage: Optional[Dict[str, Any]],\n) -> openai_api_protocol.ChatCompletionResponse:\n    \"\"\"Wrap the non-streaming chat completion results into a ChatCompletionResponse instance.\"\"\"\n    return openai_api_protocol.ChatCompletionResponse(\n        id=request_id,\n        choices=[\n            openai_api_protocol.ChatCompletionResponseChoice(\n                index=i,\n                finish_reason=finish_reasons[i],\n
                message=(\n                    openai_api_protocol.ChatCompletionMessage(role=\"assistant\", content=output_text)\n                    if not use_function_calling or finish_reason == \"error\"\n                    else openai_api_protocol.ChatCompletionMessage(\n                        role=\"assistant\", tool_calls=tool_calls\n                    )\n                ),\n                logprobs=(\n                    openai_api_protocol.LogProbs(content=logprob_results[i])\n                    if logprob_results is not None\n                    else None\n                ),\n            )\n            for i, (output_text, finish_reason, tool_calls) in enumerate(\n                zip(output_texts, finish_reasons, tool_calls_list)\n            )\n        ],\n        model=model,\n        system_fingerprint=\"\",\n        usage=usage,\n    )\n\n\ndef wrap_completion_response(  # pylint: disable=too-many-arguments\n    request_id: str,\n    model: str,\n    output_texts: List[str],\n    finish_reasons: List[str],\n    logprob_results: List[Optional[openai_api_protocol.CompletionLogProbs]],\n    usage: openai_api_protocol.CompletionUsage,\n) -> openai_api_protocol.CompletionResponse:\n    \"\"\"Wrap the non-streaming completion results into a CompletionResponse instance.\"\"\"\n    return openai_api_protocol.CompletionResponse(\n        id=request_id,\n        choices=[\n            openai_api_protocol.CompletionResponseChoice(\n                index=i,\n                finish_reason=finish_reason,\n                text=output_text,\n                logprobs=logprob_results[i],\n            )\n            for i, (output_text, finish_reason) in enumerate(zip(output_texts, finish_reasons))\n        ],\n        model=model,\n        usage=usage,\n    )\n"
  },
  {
    "path": "python/mlc_llm/serve/engine_utils.py",
    "content": "\"\"\"Utility functions for MLC Serve engine\"\"\"\n\nimport uuid\nfrom typing import Any, Callable, Dict, List, Literal, Optional, Union\n\nfrom mlc_llm.protocol import error_protocol, openai_api_protocol\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import data\n\nRequestProtocol = Union[\n    openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest\n]\n\n\ndef get_unsupported_fields(request: RequestProtocol) -> List[str]:\n    \"\"\"Get the unsupported fields of the request.\n    Return the list of unsupported field names.\n    \"\"\"\n    if isinstance(\n        request,\n        (\n            openai_api_protocol.CompletionRequest,\n            openai_api_protocol.ChatCompletionRequest,\n        ),\n    ):\n        return openai_api_protocol.openai_api_get_unsupported_fields(request)\n    raise RuntimeError(\"Cannot reach here\")\n\n\ndef openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]:\n    \"\"\"Create the generation config keyword arguments from the given request.\"\"\"\n    kwargs: Dict[str, Any] = {}\n    arg_names = [\n        \"n\",\n        \"temperature\",\n        \"top_p\",\n        \"max_tokens\",\n        \"frequency_penalty\",\n        \"presence_penalty\",\n        \"logit_bias\",\n        \"seed\",\n        \"response_format\",\n        \"debug_config\",\n    ]\n    for arg_name in arg_names:\n        kwargs[arg_name] = getattr(request, arg_name)\n    if kwargs[\"max_tokens\"] is None:\n        # Setting to -1 means the generation will not stop until\n        # it exceeds the model capability or hits a stop criterion.\n        kwargs[\"max_tokens\"] = -1\n    if request.stop is not None:\n        kwargs[\"stop_strs\"] = [request.stop] if isinstance(request.stop, str) else request.stop\n    if isinstance(request, openai_api_protocol.ChatCompletionRequest):\n        kwargs[\"logprobs\"] = request.logprobs\n        kwargs[\"top_logprobs\"] = request.top_logprobs\n 
   else:\n        logprobs = request.logprobs is not None\n        kwargs[\"logprobs\"] = logprobs\n        kwargs[\"top_logprobs\"] = request.logprobs if logprobs else 0\n    return kwargs\n\n\ndef get_generation_config(\n    request: RequestProtocol,\n    extra_stop_token_ids: Optional[List[int]] = None,\n    extra_stop_str: Optional[List[str]] = None,\n) -> GenerationConfig:\n    \"\"\"Create the MLC LLM generation config from the input request protocol.\"\"\"\n    kwargs: Dict[str, Any]\n    if isinstance(\n        request,\n        (\n            openai_api_protocol.CompletionRequest,\n            openai_api_protocol.ChatCompletionRequest,\n        ),\n    ):\n        kwargs = openai_api_get_generation_config(request)\n    else:\n        raise RuntimeError(\"Cannot reach here\")\n\n    if extra_stop_token_ids is not None:\n        stop_token_ids = kwargs.get(\"stop_token_ids\", [])\n        assert isinstance(stop_token_ids, list)\n        stop_token_ids += extra_stop_token_ids\n        kwargs[\"stop_token_ids\"] = stop_token_ids\n\n    if extra_stop_str is not None:\n        # Copy the list so that request.stop is not extended in place.\n        stop_strs = list(kwargs.get(\"stop_strs\", []))\n        stop_strs += extra_stop_str\n        kwargs[\"stop_strs\"] = stop_strs\n\n    return GenerationConfig(**kwargs)\n\n\ndef random_uuid() -> str:\n    \"\"\"Generate a random id as a hexadecimal string.\"\"\"\n    return uuid.uuid4().hex\n\n\ndef check_unsupported_fields(request: RequestProtocol) -> None:\n    \"\"\"Check if the request has unsupported fields. 
Raise BadRequestError if so.\"\"\"\n    unsupported_fields = get_unsupported_fields(request)\n    if len(unsupported_fields) != 0:\n        unsupported_fields = [f'\"{field}\"' for field in unsupported_fields]\n        raise error_protocol.BadRequestError(\n            f\"Request fields {', '.join(unsupported_fields)} are not supported right now.\",\n        )\n\n\ndef check_and_get_prompts_length(\n    prompts: List[Union[List[int], data.ImageData]], max_input_sequence_length: int\n) -> int:\n    \"\"\"Check if the total prompt length exceeds the maximum input\n    sequence length allowed by the served model. Raise BadRequestError if so.\n    Return the total prompt length.\n    \"\"\"\n    total_length: int = 0\n    for prompt in prompts:\n        total_length += len(prompt)\n    if total_length > max_input_sequence_length:\n        raise error_protocol.BadRequestError(\n            f\"Request prompt has {total_length} tokens in total,\"\n            f\" larger than the model input length limit {max_input_sequence_length}.\",\n        )\n    return total_length\n\n\ndef process_prompts(\n    input_prompts: Union[str, List[int], List[Union[str, List[int], data.ImageData]]],\n    ftokenize: Callable[[str], List[int]],\n) -> List[Union[List[int], data.ImageData]]:\n    \"\"\"Convert all input prompts to lists of token ids with the\n    given tokenization function, keeping image data as-is.\n    Return one processed entry per input prompt.\n    \"\"\"\n    error_msg = f\"Invalid request prompt {input_prompts}\"\n\n    # Case 1. The prompt is a single string.\n    if isinstance(input_prompts, str):\n        return [ftokenize(input_prompts)]\n\n    assert isinstance(input_prompts, list)\n    if len(input_prompts) == 0:\n        raise error_protocol.BadRequestError(error_msg)\n\n    # Case 2. 
The prompt is a list of token ids.\n    if isinstance(input_prompts[0], int):\n        assert isinstance(input_prompts, list)\n        if not all(isinstance(token_id, int) for token_id in input_prompts):\n            raise error_protocol.BadRequestError(error_msg)\n        return [input_prompts]  # type: ignore\n\n    # Case 3. A list of prompts.\n    output_prompts: List[Union[List[int], data.ImageData]] = []\n    for input_prompt in input_prompts:\n        if isinstance(input_prompt, str):\n            output_prompts.append(ftokenize(input_prompt))\n        elif isinstance(input_prompt, list) and all(\n            isinstance(token_id, int) for token_id in input_prompt\n        ):\n            output_prompts.append(input_prompt)\n        elif isinstance(input_prompt, data.ImageData):\n            output_prompts.append(input_prompt)\n        else:\n            raise error_protocol.BadRequestError(error_msg)\n    return output_prompts\n\n\ndef convert_prompts_to_data(\n    prompts: Union[str, List[int], List[Union[str, List[int], data.Data]]],\n) -> List[data.Data]:\n    \"\"\"Convert the given prompts, a combination of strings, token id lists\n    and/or data, into a list of data.\"\"\"\n    if isinstance(prompts, data.Data):\n        return [prompts]\n    if isinstance(prompts, str):\n        return [data.TextData(prompts)]\n    if isinstance(prompts[0], int):\n        assert isinstance(prompts, list) and all(isinstance(token_id, int) for token_id in prompts)\n        return [data.TokenData(prompts)]  # type: ignore\n    return [convert_prompts_to_data(x)[0] for x in prompts]  # type: ignore\n\n\nclass ErrorCleanupScope:\n    \"\"\"Scope to call cleanup when an error is thrown.\n\n    This class provides an important pattern for proper cleanup\n    when a CancelledError or another exception happens in an async scope.\n\n    Parameters\n    ----------\n    cleanup : Callable\n        A callable function to trigger at scope exit during an exception.\n\n    Note\n    ----\n    This helper is 
motivated by the need to properly\n    abort an async generator and trigger the corresponding\n    cleanup functions. Naively using the try-except\n    pattern leads to bugs when async generators\n    are chained.\n\n    .. code:: python\n\n        class EngineNotSafe:\n            async def _inner_gen(self, request):\n                request_id = self.get_request_id()\n                self.add_request(request)\n                try:\n                    async for res in await producer_stream:\n                        yield res\n                except asyncio.CancelledError:\n                    self.abort(request_id)\n\n            async def generate(self, request):\n                async for res in self._inner_gen(request):\n                    # an async error (e.g. CancelledError) can be raised here,\n                    # and the except clause in _inner_gen will not catch it\n                    res = await process(res)\n                    yield res\n\n    The above except pattern is not safe.\n    This is because CancelledError may also be raised\n    outside _inner_gen, in the generate function\n    between iterations.\n\n    Instead, we use ErrorCleanupScope to safeguard the\n    generation process. The scope always runs the cleanup\n    in its exit function when an exception is raised.\n\n    .. 
code:: python\n\n        class EngineSafe:\n            async def _inner_gen(self, request):\n                request_id = self.get_request_id()\n                self.add_request(request)\n                with ErrorCleanupScope(lambda: self.abort(request_id)):\n                    async for res in await producer_stream:\n                        yield res\n\n            async def generate(self, request):\n                async for res in self._inner_gen(request):\n                    # even if an async error is raised here,\n                    # the ErrorCleanupScope in _inner_gen still\n                    # runs the cleanup when the generator exits\n                    res = await process(res)\n                    yield res\n    \"\"\"\n\n    cleanup: Callable\n\n    def __init__(self, cleanup: Callable):\n        self.cleanup = cleanup\n\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, traceback) -> None:\n        # only clean up when an exception was raised inside the scope\n        if exc_type is not None:\n            self.cleanup()\n\n\n# ====== Embedding Engine Utilities ======\n\n\ndef load_embedding_params(model_weight_path, device, model_metadata) -> list:\n    \"\"\"Load embedding model parameters from weight directory.\n\n    Parameters\n    ----------\n    model_weight_path : str\n        Path to the model weight directory.\n    device : tvm.runtime.Device\n        The target device.\n    model_metadata : dict\n        The model metadata dictionary containing param info.\n\n    Returns\n    -------\n    params : list\n        List of tvm.runtime.Tensor parameters in metadata order.\n    \"\"\"\n    from tvm.contrib import tvmjs  # pylint: disable=import-outside-toplevel\n\n    params, meta = tvmjs.load_tensor_cache(model_weight_path, device)\n    param_names = [param[\"name\"] for param in model_metadata[\"params\"]]\n    assert len(param_names) == meta[\"ParamSize\"]\n    return [params[name] for name in param_names]\n\n\ndef 
get_embedding_metadata(config: Dict[str, Any]) -> Optional[Dict[str, Any]]:\n    \"\"\"Read embedding metadata from mlc-chat-config or model lib metadata.\n\n    Parameters\n    ----------\n    config : Dict[str, Any]\n        The configuration dictionary containing model metadata.\n\n    Returns\n    -------\n    embedding_metadata : Optional[Dict[str, Any]]\n        The embedding metadata dictionary, or None if the model is not an embedding model.\n    \"\"\"\n    if config.get(\"model_task\") == \"embedding\":\n        return config.get(\"embedding_metadata\")\n    return None\n\n\ndef detect_embedding_model_type(mod) -> Literal[\"encoder\", \"decoder\"]:\n    \"\"\"Detect embedding model type from compiled TVM module functions.\n\n    Parameters\n    ----------\n    mod : tvm.runtime.Module\n        The VM module with model functions.\n\n    Returns\n    -------\n    model_type : str\n        \"encoder\" for BERT-style models, \"decoder\" for Qwen3-Embedding-style models.\n    \"\"\"\n    has_embed = mod.implements_function(\"embed\")\n    has_prefill_to_hidden = mod.implements_function(\"prefill_to_last_hidden_states\")\n    has_prefill = mod.implements_function(\"prefill\")\n\n    if has_embed and has_prefill_to_hidden:\n        return \"decoder\"\n    if has_prefill:\n        return \"encoder\"\n    raise ValueError(\n        \"Model does not support embedding inference. \"\n        \"Expected 'embed' + 'prefill_to_last_hidden_states' (decoder) \"\n        \"or 'prefill' (encoder).\"\n    )\n"
  },
  {
    "path": "python/mlc_llm/serve/entrypoints/__init__.py",
    "content": "\"\"\"The entrypoints for MLC LLM server.\"\"\"\n\nfrom . import (\n    debug_entrypoints,\n    metrics_entrypoints,\n    microserving_entrypoints,\n    openai_entrypoints,\n)\n"
  },
  {
    "path": "python/mlc_llm/serve/entrypoints/debug_entrypoints.py",
    "content": "\"\"\"MLC LLM server debug entrypoints\"\"\"\n\nimport json\nfrom http import HTTPStatus\n\nimport fastapi\n\nfrom mlc_llm.protocol import error_protocol\nfrom mlc_llm.serve.server import ServerContext\n\napp = fastapi.APIRouter()\n\n################ /debug/dump_event_trace ################\n\n\n@app.post(\"/debug/dump_event_trace\")\nasync def debug_dump_event_trace(request: fastapi.Request):\n    \"\"\"Return the recorded events in Chrome Trace Event Format in JSON string.\n    The input request payload should have only one field, specifying the\n    model to query. For example: `{\"model\": \"Llama-2-7b-chat-hf-q0f16\"}`.\n    \"\"\"\n    # Get the raw request body as bytes\n    request_raw_data = await request.body()\n    request_json_str = request_raw_data.decode(\"utf-8\")\n    try:\n        # Parse the JSON string\n        request_dict = json.loads(request_json_str)\n    except json.JSONDecodeError:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST, message=f\"Invalid request {request_json_str}\"\n        )\n    if \"model\" not in request_dict:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST, message=f\"Invalid request {request_json_str}\"\n        )\n\n    # Check the requested model.\n    model = request_dict[\"model\"]\n\n    server_context: ServerContext = ServerContext.current()\n    async_engine = server_context.get_engine(model)\n\n    if async_engine is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{model}\" is not served.',\n        )\n    if async_engine.state.trace_recorder is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{model}\" does not enable tracing',\n        )\n\n    return json.loads(async_engine.state.trace_recorder.dump_json())\n\n\n################ 
/debug/cuda_profiler_start/end ################\n\n\n@app.post(\"/debug/cuda_profiler_start\")\nasync def debug_cuda_profiler_start(_request: fastapi.Request):\n    \"\"\"Start the CUDA profiler for the engine. Only for debugging purposes.\"\"\"\n    server_context: ServerContext = ServerContext.current()\n    # Since the CUDA profiler is process-wide, calling the function for one model is sufficient.\n    for model in server_context.get_model_list():\n        async_engine = server_context.get_engine(model)\n        async_engine._debug_call_func_on_all_worker(  # pylint: disable=protected-access\n            \"mlc.debug_cuda_profiler_start\"\n        )\n        break\n\n\n@app.post(\"/debug/cuda_profiler_stop\")\nasync def debug_cuda_profiler_stop(_request: fastapi.Request):\n    \"\"\"Stop the CUDA profiler for the engine. Only for debugging purposes.\"\"\"\n    server_context: ServerContext = ServerContext.current()\n    # Since the CUDA profiler is process-wide, calling the function for one model is sufficient.\n    for model in server_context.get_model_list():\n        async_engine = server_context.get_engine(model)\n        async_engine._debug_call_func_on_all_worker(  # pylint: disable=protected-access\n            \"mlc.debug_cuda_profiler_stop\"\n        )\n        break\n\n\n@app.post(\"/debug/dump_engine_metrics\")\nasync def debug_dump_engine_metrics(request: fastapi.Request):\n    \"\"\"Dump the engine metrics for the engine. 
Only for debugging purposes.\"\"\"\n    # Get the raw request body as bytes\n    request_raw_data = await request.body()\n    request_json_str = request_raw_data.decode(\"utf-8\")\n    try:\n        # Parse the JSON string\n        request_dict = json.loads(request_json_str)\n    except json.JSONDecodeError:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST, message=f\"Invalid request {request_json_str}\"\n        )\n\n    # Check the requested model.\n    model = request_dict.get(\"model\", None)\n\n    server_context: ServerContext = ServerContext.current()\n    async_engine = server_context.get_engine(model)\n    if async_engine is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{model}\" is not served.',\n        )\n    res = await async_engine.metrics()\n    return res\n\n\n@app.post(\"/debug/reset_engine\")\nasync def debug_reset_engine_stats(request: fastapi.Request):\n    \"\"\"Reset the engine, cleaning up all running data and metrics.\"\"\"\n    # Get the raw request body as bytes\n    request_raw_data = await request.body()\n    request_json_str = request_raw_data.decode(\"utf-8\")\n    try:\n        # Parse the JSON string\n        request_dict = json.loads(request_json_str)\n    except json.JSONDecodeError:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST, message=f\"Invalid request {request_json_str}\"\n        )\n    if \"model\" not in request_dict:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST, message=f\"Invalid request {request_json_str}\"\n        )\n\n    # Check the requested model.\n    model = request_dict[\"model\"]\n\n    server_context: ServerContext = ServerContext.current()\n    async_engine = server_context.get_engine(model)\n    if async_engine is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{model}\" is not served.',\n        )\n    async_engine.reset()\n"
  },
  {
    "path": "python/mlc_llm/serve/entrypoints/metrics_entrypoints.py",
    "content": "\"\"\"MLC LLM server metrics entrypoints\"\"\"\n\nimport fastapi\nfrom fastapi.responses import PlainTextResponse\n\nfrom mlc_llm.serve.server import ServerContext\n\napp = fastapi.APIRouter()\n\n################ /metrics ################\n\n\n@app.get(\"/metrics\", response_class=PlainTextResponse)\nasync def metrics(_request: fastapi.Request):\n    \"\"\"Return the engine metrics in Prometheus text format.\"\"\"\n    server_context: ServerContext = ServerContext.current()\n    # Use the metrics from the first engine for now.\n    # TODO(mlc-team): consider refactoring the server context to a\n    # single engine, since multiple AsyncMLCEngine instances do not work well with each other.\n    # We need to work within the internal engine instead.\n    for model in server_context.get_model_list():\n        async_engine = server_context.get_engine(model)\n        return (await async_engine.metrics()).prometheus_text()\n"
  },
  {
    "path": "python/mlc_llm/serve/entrypoints/microserving_entrypoints.py",
    "content": "\"\"\"MicroServing server entrypoints in MLC LLM\"\"\"\n\nimport fastapi\n\nfrom mlc_llm.protocol.debug_protocol import DisaggConfig\nfrom mlc_llm.protocol.microserving_protocol import (\n    PrepRecvRequest,\n    PrepRecvResponse,\n    RemoteSendRequest,\n    StartGenerateRequest,\n)\nfrom mlc_llm.protocol.openai_api_protocol import StreamOptions\n\nfrom .openai_entrypoints import request_completion\n\napp = fastapi.APIRouter()\n\n\n################ MicroServing Endpoints ################\n\n\n@app.post(\"/microserving/prep_recv\")\nasync def prep_recv(request: PrepRecvRequest, raw_request: fastapi.Request) -> PrepRecvResponse:\n    \"\"\"Handle the microserving request for receive preparation.\n    Match the prompt in the prefix cache (when enabled),\n    allocate entries in the KV cache to prepare receiving the KV data of the prompt.\n    Return the matched prefix length and the allocated KV entry metadata.\n    \"\"\"\n    request.debug_config.disagg_config = DisaggConfig(\n        kind=\"prepare_receive\",\n        kv_window_begin=0,  # always zero for prepare_receive\n        kv_window_end=request.end,\n    )\n    request.stream_options = StreamOptions(include_usage=True)\n    request.stream = False\n\n    response = await request_completion(request=request, raw_request=raw_request)\n    assert response.usage is not None\n    assert response.usage.extra is not None\n    assert \"prefix_matched_length\" in response.usage.extra\n    assert \"kv_append_metadata\" in response.usage.extra\n    return PrepRecvResponse(\n        prefix_matched_length=response.usage.extra[\"prefix_matched_length\"],\n        kv_append_metadata=response.usage.extra[\"kv_append_metadata\"],\n    )\n\n\n@app.post(\"/microserving/remote_send\")\nasync def remote_send(request: RemoteSendRequest, raw_request: fastapi.Request):\n    \"\"\"Compute and generate the KV data of the prompt in the specified KV window.\n    Send the KV data to the destination server.\"\"\"\n    
request.debug_config.disagg_config = DisaggConfig(\n        kind=\"remote_send\",\n        kv_window_begin=request.begin,\n        kv_window_end=request.end,\n        kv_append_metadata=request.kv_addr_info,\n        dst_group_offset=request.recv_rank,\n    )\n    request.stream_options = StreamOptions(include_usage=True)\n    request.stream = False\n\n    await request_completion(request=request, raw_request=raw_request)\n    return {}\n\n\n@app.post(\"/microserving/start_generate\")\nasync def start_generate(request: StartGenerateRequest, raw_request: fastapi.Request):\n    \"\"\"Prefill the prompt in the specified KV window, and start decode.\"\"\"\n    request.debug_config.disagg_config = DisaggConfig(\n        kind=\"start_generation\",\n        kv_window_begin=request.begin,\n    )\n    return await request_completion(request=request, raw_request=raw_request)\n"
  },
  {
    "path": "python/mlc_llm/serve/entrypoints/openai_entrypoints.py",
    "content": "\"\"\"OpenAI API-compatible server entrypoints in MLC LLM\"\"\"\n\n# pylint: disable=too-many-locals,too-many-return-statements,too-many-statements,fixme\nimport base64\nimport struct\nfrom http import HTTPStatus\nfrom typing import AsyncGenerator, List, Optional\n\nimport fastapi\nimport numpy as np\n\nfrom mlc_llm.protocol import error_protocol\nfrom mlc_llm.protocol.openai_api_protocol import (\n    ChatCompletionRequest,\n    CompletionLogProbs,\n    CompletionRequest,\n    EmbeddingObject,\n    EmbeddingRequest,\n    EmbeddingResponse,\n    EmbeddingUsage,\n    ListResponse,\n    LogProbsContent,\n    ModelResponse,\n)\nfrom mlc_llm.serve import engine_base, engine_utils\nfrom mlc_llm.serve.server import ServerContext\n\n\ndef verify_api_key(request: fastapi.Request):\n    \"\"\"Function to verify API key\"\"\"\n    server_context = ServerContext.current()\n    # Only perform verification when API key is configured\n    if server_context is not None and server_context.api_key is not None:\n        provided_key = request.headers.get(\"Authorization\", \"\").replace(\"Bearer \", \"\")\n        if provided_key != server_context.api_key:\n            raise fastapi.HTTPException(status_code=401, detail=\"Invalid API Key\")\n\n\napp = fastapi.APIRouter(dependencies=[fastapi.Depends(verify_api_key)])\n\n\n################ v1/embeddings ################\n\n\n@app.post(\"/v1/embeddings\")\nasync def request_embedding(request: EmbeddingRequest):\n    \"\"\"OpenAI-compatible embedding API.\n    API reference: https://platform.openai.com/docs/api-reference/embeddings/create\n    \"\"\"\n    server_context: ServerContext = ServerContext.current()\n    embedding_engine = server_context.get_embedding_engine(request.model)\n    if embedding_engine is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{request.model}\" is not served '\n            f\"as an embedding 
model.\",\n        )\n\n    # Normalize input to List[str]\n    inputs: List[str]\n    if isinstance(request.input, str):\n        inputs = [request.input]\n    elif (\n        isinstance(request.input, list)\n        and len(request.input) > 0\n        and isinstance(request.input[0], str)\n    ):\n        inputs = list(request.input)  # type: ignore[arg-type]\n    else:\n        # Token ID inputs (List[int] or List[List[int]]) — decode back to strings\n        if isinstance(request.input[0], int):\n            inputs = [embedding_engine.tokenizer.decode(request.input)]  # type: ignore[arg-type]\n        else:\n            inputs = [\n                embedding_engine.tokenizer.decode(ids)  # type: ignore[arg-type]\n                for ids in request.input\n            ]\n\n    # Run embedding inference (async — does not block the event loop)\n    try:\n        embeddings, total_tokens = await embedding_engine.async_embed(inputs)\n    except Exception as exc:  # pylint: disable=broad-except\n        return error_protocol.create_error_response(\n            HTTPStatus.INTERNAL_SERVER_ERROR,\n            message=f\"Embedding inference failed: {exc}\",\n        )\n\n    # Optional: truncate dimensions (Matryoshka-style).\n    # This is API-level renormalization after dimension truncation,\n    # independent of model metadata normalize. 
Always renormalize\n    # truncated vectors to maintain unit length per OpenAI API contract.\n    if request.dimensions is not None:\n        for i, emb in enumerate(embeddings):\n            vec = np.array(emb[: request.dimensions], dtype=np.float32)\n            norm = np.linalg.norm(vec)\n            if norm > 1e-12:\n                vec = vec / norm\n            embeddings[i] = vec.tolist()\n\n    # Build response data\n    resp_data = []\n    for i, emb in enumerate(embeddings):\n        if request.encoding_format == \"base64\":\n            binary = struct.pack(f\"<{len(emb)}f\", *emb)\n            resp_data.append(\n                EmbeddingObject(\n                    embedding=base64.b64encode(binary).decode(\"utf-8\"),\n                    index=i,\n                )\n            )\n        else:\n            resp_data.append(EmbeddingObject(embedding=emb, index=i))\n\n    return EmbeddingResponse(\n        data=resp_data,\n        model=request.model,\n        usage=EmbeddingUsage(prompt_tokens=total_tokens, total_tokens=total_tokens),\n    )\n\n\n################ v1/models ################\n\n\n@app.get(\"/v1/models\")\nasync def request_models() -> ListResponse:\n    \"\"\"OpenAI-compatible served model query API.\n    API reference: https://platform.openai.com/docs/api-reference/models\n    \"\"\"\n    server_context: ServerContext = ServerContext.current()\n    return ListResponse(data=[ModelResponse(id=model) for model in server_context.get_model_list()])\n\n\n################ v1/completions ################\n\n\n@app.post(\"/v1/completions\")\nasync def request_completion(request: CompletionRequest, raw_request: fastapi.Request):\n    \"\"\"OpenAI-compatible completion API.\n    API reference: https://platform.openai.com/docs/api-reference/completions/create\n    \"\"\"\n    # - Check the requested model.\n    server_context: ServerContext = ServerContext.current()\n    request_final_usage_include_extra = server_context.enable_debug\n    
request_include_debug_config = server_context.enable_debug\n\n    if not request_include_debug_config:\n        request.debug_config = None\n\n    async_engine = server_context.get_engine(request.model)\n    if async_engine is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{request.model}\" is not served.',\n        )\n    # FIXME: This is a temporary solution to make sure\n    # prep_recv, remote_send and start_generation process the same request\n    request_id = request.user if request.user is not None else f\"cmpl-{engine_utils.random_uuid()}\"\n\n    # Streaming response.\n    if request.stream:\n        # We manually get the first response from the generator to\n        # capture potential exceptions in this scope, rather than\n        # the StreamingResponse scope.\n        stream_generator = async_engine._handle_completion(  # pylint: disable=protected-access\n            request,\n            request_id,\n            request_final_usage_include_extra=request_final_usage_include_extra,\n        )\n        first_response = await anext(  # type: ignore  # pylint: disable=undefined-variable\n            stream_generator\n        )\n\n        async def completion_stream_generator() -> AsyncGenerator[str, None]:\n            if isinstance(first_response, StopAsyncIteration):\n                yield \"data: [DONE]\\n\\n\"\n                return\n            yield f\"data: {first_response.model_dump_json(by_alias=True)}\\n\\n\"\n            async for response in stream_generator:\n                yield f\"data: {response.model_dump_json(by_alias=True)}\\n\\n\"\n            yield \"data: [DONE]\\n\\n\"\n\n        return fastapi.responses.StreamingResponse(\n            completion_stream_generator(), media_type=\"text/event-stream\"\n        )\n\n    # Normal response.\n    request_final_usage = None\n    output_texts = [\"\"] * request.n\n    finish_reasons: 
List[Optional[str]] = [None] * request.n\n    logprob_results: List[Optional[CompletionLogProbs]] = [None] * request.n\n\n    async for response in async_engine._handle_completion(  # pylint: disable=protected-access\n        request,\n        request_id,\n        request_final_usage_include_extra=request_final_usage_include_extra,\n    ):\n        if await raw_request.is_disconnected():\n            # In non-streaming cases, the engine will not be notified\n            # when the request is disconnected.\n            # Therefore, we check if it is disconnected each time,\n            # and explicitly return.\n            # Note that request abort is triggered when the async for loop and function scope end.\n            return error_protocol.create_error_response(\n                HTTPStatus.BAD_REQUEST, message=\"The request has disconnected\"\n            )\n        # this is the final chunk\n        if response.usage is not None:\n            request_final_usage = response.usage\n            # remove extra information if debug is not enabled\n            if not server_context.enable_debug:\n                request_final_usage.extra = None\n            continue\n        for choice in response.choices:\n            output_texts[choice.index] += choice.text\n            if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                finish_reasons[choice.index] = choice.finish_reason\n            if choice.logprobs is not None:\n                if logprob_results[choice.index] is None:\n                    logprob_results[choice.index] = choice.logprobs\n                else:\n                    logprob_results[choice.index].token_logprobs.extend(\n                        choice.logprobs.token_logprobs\n                    )\n                    logprob_results[choice.index].tokens.extend(choice.logprobs.tokens)\n                    logprob_results[choice.index].top_logprobs.extend(choice.logprobs.top_logprobs)\n\n    return 
engine_base.wrap_completion_response(\n        request_id=request_id,\n        model=request.model,\n        output_texts=output_texts,\n        finish_reasons=finish_reasons,\n        logprob_results=logprob_results,\n        usage=request_final_usage,\n    )\n\n\n################ v1/chat/completions ################\n\n\n@app.post(\"/v1/chat/completions\")\nasync def request_chat_completion(\n    request: ChatCompletionRequest, raw_request: fastapi.Request\n):  # pylint: disable=too-many-branches\n    \"\"\"OpenAI-compatible chat completion API.\n    API reference: https://platform.openai.com/docs/api-reference/chat\n    \"\"\"\n    # - Check the requested model.\n    server_context: ServerContext = ServerContext.current()\n    request_final_usage_include_extra = server_context.enable_debug\n    request_include_debug_config = server_context.enable_debug\n\n    if not request_include_debug_config:\n        request.debug_config = None\n\n    async_engine = server_context.get_engine(request.model)\n    if async_engine is None:\n        return error_protocol.create_error_response(\n            HTTPStatus.BAD_REQUEST,\n            message=f'The requested model \"{request.model}\" is not served.',\n        )\n    # FIXME: This is a temporary solution to make sure\n    # prep_recv, remote_send and start_generation process the same request\n    request_id = (\n        request.user if request.user is not None else f\"chatcmpl-{engine_utils.random_uuid()}\"\n    )\n\n    # Streaming response.\n    if request.stream:\n        # We manually get the first response from the generator to\n        # capture potential exceptions in this scope, rather than\n        # the StreamingResponse scope.\n        stream_generator = async_engine._handle_chat_completion(  # pylint: disable=protected-access\n            request,\n            request_id,\n            request_final_usage_include_extra=request_final_usage_include_extra,\n        )\n        first_response = await anext(  # type: 
ignore  # pylint: disable=undefined-variable\n            stream_generator\n        )\n\n        async def completion_stream_generator() -> AsyncGenerator[str, None]:\n            if isinstance(first_response, StopAsyncIteration):\n                yield \"data: [DONE]\\n\\n\"\n                return\n            yield f\"data: {first_response.model_dump_json(by_alias=True)}\\n\\n\"\n            async for response in stream_generator:\n                yield f\"data: {response.model_dump_json(by_alias=True)}\\n\\n\"\n            yield \"data: [DONE]\\n\\n\"\n\n        return fastapi.responses.StreamingResponse(\n            completion_stream_generator(), media_type=\"text/event-stream\"\n        )\n\n    # Normal response.\n    request_final_usage = None\n    output_texts = [\"\" for _ in range(request.n)]\n    finish_reasons: List[Optional[str]] = [None for _ in range(request.n)]\n    logprob_results: Optional[List[List[LogProbsContent]]] = (\n        [[] for _ in range(request.n)] if request.logprobs else None\n    )\n\n    async for response in async_engine._handle_chat_completion(  # pylint: disable=protected-access\n        request,\n        request_id,\n        request_final_usage_include_extra=request_final_usage_include_extra,\n    ):\n        if await raw_request.is_disconnected():\n            # In non-streaming cases, the engine will not be notified\n            # when the request is disconnected.\n            # Therefore, we check if it is disconnected each time,\n            # no need to explicitly abort, as the chat completion\n            # return will trigger abort call\n            return error_protocol.create_error_response(\n                HTTPStatus.BAD_REQUEST, message=\"The request has disconnected\"\n            )\n        # usage is always the last chunk\n        if response.usage is not None:\n            request_final_usage = response.usage\n            # remove extra information if debug is not enabled\n            if not 
server_context.enable_debug:\n                request_final_usage.extra = None\n\n        for choice in response.choices:\n            assert isinstance(choice.delta.content, str)\n            output_texts[choice.index] += choice.delta.content\n            if choice.finish_reason is not None and finish_reasons[choice.index] is None:\n                finish_reasons[choice.index] = choice.finish_reason\n            if choice.logprobs is not None:\n                assert logprob_results is not None\n                logprob_results[choice.index] += choice.logprobs.content\n\n    assert all(finish_reason is not None for finish_reason in finish_reasons)\n    use_function_calling, tool_calls_list = engine_base.process_function_call_output(\n        output_texts, finish_reasons\n    )\n\n    return engine_base.wrap_chat_completion_response(\n        request_id=request_id,\n        model=request.model,\n        output_texts=output_texts,\n        finish_reasons=finish_reasons,\n        tool_calls_list=tool_calls_list,\n        logprob_results=logprob_results,\n        use_function_calling=use_function_calling,\n        usage=request_final_usage,\n    )\n"
  },
  {
    "path": "python/mlc_llm/serve/event_trace_recorder.py",
    "content": "\"\"\"The event trace recorder in MLC LLM serving\"\"\"\n\nimport tvm_ffi\nfrom tvm.runtime import Object\n\nfrom . import _ffi_api\n\n\n@tvm_ffi.register_object(\"mlc.serve.EventTraceRecorder\")  # pylint: disable=protected-access\nclass EventTraceRecorder(Object):\n    \"\"\"The event trace recorder for requests.\"\"\"\n\n    def __init__(self) -> None:  # pylint: disable=super-init-not-called\n        \"\"\"Initialize a trace recorder.\"\"\"\n        self.__init_handle_by_constructor__(\n            _ffi_api.EventTraceRecorder  # type: ignore  # pylint: disable=no-member\n        )\n\n    def add_event(self, request_id: str, event: str) -> None:\n        \"\"\"Record a event for the input request in the trace recorder.\n\n        Parameters\n        ----------\n        request_id : str\n            The subject request of the event.\n\n        event : str\n            The event in a string name.\n            It can have one of the following patterns:\n            - \"start xxx\", which marks the start of event \"xxx\",\n            - \"finish xxx\", which marks the finish of event \"xxx\",\n            - \"yyy\", which marks the instant event \"yyy\".\n            The \"starts\" and \"finishes\" will be automatically paired in the trace recorder.\n        \"\"\"\n        return _ffi_api.EventTraceRecorderAddEvent(  # type: ignore  # pylint: disable=no-member\n            self, request_id, event\n        )\n\n    def dump_json(self) -> str:\n        \"\"\"Dump the logged events in Chrome Trace Event Format in JSON string.\"\"\"\n        return _ffi_api.EventTraceRecorderDumpJSON(self)  # type: ignore  # pylint: disable=no-member\n"
  },
  {
    "path": "python/mlc_llm/serve/radix_tree.py",
    "content": "\"\"\"The Paged Radix Tree class.\"\"\"\n\nfrom typing import List, Tuple, Union\n\nimport tvm_ffi\nfrom tvm.runtime import Object, ShapeTuple\n\nfrom . import _ffi_api\n\n\n@tvm_ffi.register_object(\"mlc.serve.PagedRadixTree\")  # pylint: disable=protected-access\nclass PagedRadixTree(Object):\n    \"\"\"The paged radix tree to manage prefix and sequence.\"\"\"\n\n    def __init__(self):  # pylint: disable=super-init-not-called\n        \"\"\"\n        Constructor of paged radix tree.\n        \"\"\"\n        self.__init_handle_by_constructor__(_ffi_api.PagedRadixTree)  # type: ignore  # pylint: disable=no-member\n\n    def match(self, tokens: Union[ShapeTuple, List, Tuple]) -> Tuple[int, ShapeTuple]:\n        \"\"\"\n        Get all sequences with longest common prefix with given prefix tokens.\n\n        Parameters\n        ----------\n        tokens : Union[ShapeTuple, List, Tuple]\n            The prefix tokens for reference.\n\n        Returns\n        ------\n        matched_offset : int\n            The matched prefix length.\n        seq_ids : ShapeTuple\n            The array of matched sequence indice.\n        \"\"\"\n        if isinstance(tokens, (list, tuple)):\n            tokens = ShapeTuple(tokens)\n        output = _ffi_api.PagedRadixTreeMatchPrefix(self, tokens)  # type: ignore  # pylint: disable=no-member\n        if len(output) == 1:\n            return output[0], []\n        return output[0], output[1:]\n\n    def add(self, seq_id: int) -> None:\n        \"\"\"\n        Add an empty sequence.\n\n        Parameters\n        ----------\n        seq_id : int\n            The sequence ID for index.\n        \"\"\"\n        _ffi_api.PagedRadixTreeAddSequence(self, seq_id)  # type: ignore  # pylint: disable=no-member\n\n    def remove(self, seq_id: int) -> None:\n        \"\"\"\n        Remove a sequence.\n\n        Parameters\n        ----------\n        seq_id : int\n            The sequence ID to remove.\n        \"\"\"\n        
_ffi_api.PagedRadixTreeRemoveSequence(self, seq_id)  # type: ignore  # pylint: disable=no-member\n\n    def extend(self, seq_id: int, tokens: Union[ShapeTuple, List, Tuple]) -> None:\n        \"\"\"\n        Extend a sequence with the given tokens.\n\n        Parameters\n        ----------\n        seq_id : int\n            The sequence ID for index.\n        tokens : Union[ShapeTuple, List, Tuple]\n            The given tokens to extend.\n        \"\"\"\n        if isinstance(tokens, (list, tuple)):\n            tokens = ShapeTuple(tokens)\n        _ffi_api.PagedRadixTreeExtendSequence(self, seq_id, tokens)  # type: ignore  # pylint: disable=no-member\n\n    def rollback(self, seq_id: int, num_tokens: int) -> None:\n        \"\"\"\n        Roll back a sequence by a given number of tokens.\n\n        Parameters\n        ----------\n        seq_id : int\n            The sequence ID for index.\n        num_tokens : int\n            The number of tokens to be rolled back.\n        \"\"\"\n        _ffi_api.PagedRadixTreeRollBackSequence(self, seq_id, num_tokens)  # type: ignore  # pylint: disable=no-member\n\n    def fork(self, seq_id: int, parent_seq_id: int, forked_offset: int) -> None:\n        \"\"\"\n        Fork a sequence from a parent sequence at a given position.\n\n        Parameters\n        ----------\n        seq_id : int\n            The new sequence ID.\n        parent_seq_id : int\n            The parent sequence ID to fork from.\n        forked_offset : int\n            The position in the parent sequence to fork at.\n            The valid range is [1, length of the parent sequence].\n            If the position equals the length of the parent sequence,\n            the new sequence will copy the entire parent sequence.\n        \"\"\"\n        _ffi_api.PagedRadixTreeForkSequence(self, seq_id, parent_seq_id, forked_offset)  # type: ignore  # pylint: disable=no-member\n\n    def get(self, seq_id: int) -> ShapeTuple:\n        \"\"\"\n        Get all tokens of a sequence.\n\n        
Parameters\n        ----------\n        seq_id : int\n            The sequence ID for index.\n\n        Returns\n        -------\n        tokens : ShapeTuple\n            The sequence tokens.\n        \"\"\"\n        return _ffi_api.PagedRadixTreeGetSequence(self, seq_id)  # type: ignore  # pylint: disable=no-member\n\n    def get_length(self, seq_id: int) -> int:\n        \"\"\"\n        Get a sequence's length.\n\n        Parameters\n        ----------\n        seq_id : int\n            The sequence ID for index.\n\n        Returns\n        -------\n        length : int\n            The sequence length.\n        \"\"\"\n        return _ffi_api.PagedRadixTreeGetSequenceLength(self, seq_id)  # type: ignore  # pylint: disable=no-member\n\n    def free_capacity(self) -> int:\n        \"\"\"\n        Get the remaining token capacity of the paged radix tree.\n\n        Returns\n        -------\n        capacity : int\n            The remaining token capacity of the paged radix tree.\n        \"\"\"\n        return _ffi_api.PagedRadixTreeFreeCapacity(self)  # type: ignore  # pylint: disable=no-member\n"
  },
  {
    "path": "python/mlc_llm/serve/request.py",
    "content": "\"\"\"The request class in MLC LLM serving\"\"\"\n\nfrom typing import List\n\nimport tvm_ffi\nfrom tvm.runtime import Object\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\n\nfrom . import _ffi_api\nfrom .data import Data\n\n\n@tvm_ffi.register_object(\"mlc.serve.Request\")  # pylint: disable=protected-access\nclass Request(Object):\n    \"\"\"The user submitted text-generation request, which contains\n    a unique request id, a list of multi-modal inputs, a set of\n    generation configuration parameters.\n\n    Note\n    ----\n    Do not explicitly construct this class.\n    Construct this object via engine.create_request functions.\n    \"\"\"\n\n    @property\n    def inputs(self) -> List[Data]:\n        \"\"\"The inputs of the request.\"\"\"\n        return _ffi_api.RequestGetInputs(self)  # type: ignore  # pylint: disable=no-member\n\n    @property\n    def generation_config(self) -> GenerationConfig:\n        \"\"\"The generation config of the request.\"\"\"\n        return GenerationConfig.model_validate_json(\n            _ffi_api.RequestGetGenerationConfigJSON(self)  # type: ignore  # pylint: disable=no-member\n        )\n"
  },
  {
    "path": "python/mlc_llm/serve/server/__init__.py",
    "content": "\"\"\"The server related data structure and tools in MLC LLM serve.\"\"\"\n\nfrom .popen_server import PopenServer\nfrom .server_context import ServerContext\n"
  },
  {
    "path": "python/mlc_llm/serve/server/popen_server.py",
    "content": "\"\"\"The MLC LLM server launched in a subprocess.\"\"\"\n\nimport os\nimport subprocess\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Literal, Optional, Union\n\nimport psutil\nimport requests\nfrom tvm.runtime import Device\n\nfrom mlc_llm.serve.config import EngineConfig\nfrom mlc_llm.serve.engine_base import _check_engine_config\n\n\nclass PopenServer:  # pylint: disable=too-many-instance-attributes\n    \"\"\"The wrapper of MLC LLM server, which runs the server in\n    a background subprocess.\n\n    This server can be used for debugging purposes.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        model: str,\n        device: Union[str, Device] = \"auto\",\n        *,\n        model_lib: Optional[str] = None,\n        mode: Literal[\"local\", \"interactive\", \"server\"] = \"local\",\n        engine_config: Optional[EngineConfig] = None,\n        enable_debug: bool = True,\n        enable_tracing: bool = False,\n        host: str = \"127.0.0.1\",\n        port: int = 8082,\n    ) -> None:\n        \"\"\"Please check out `python/mlc_llm/cli/serve.py` for the server arguments.\"\"\"\n        # - Check the fields fields of `engine_config`.\n        if engine_config is None:\n            engine_config = EngineConfig()\n        _check_engine_config(model, model_lib, mode, engine_config)\n\n        self.model = model\n        self.model_lib = model_lib\n        self.device = device\n        self.mode = mode\n        self.enable_debug = enable_debug\n        self.engine_config = engine_config\n        self.enable_tracing = enable_tracing\n        self.enable_debug = enable_debug\n        self.host = host\n        self.port = port\n        self._proc: Optional[subprocess.Popen] = None\n\n        self.base_url = \"\"\n        self.openai_v1_base_url = \"\"\n\n    def start(  # pylint: disable=too-many-branches,too-many-statements\n        self, extra_env=None\n    ) -> None:\n        
\"\"\"Launch the server in a popen subprocess.\n        Wait until the server becomes ready before return.\n        \"\"\"\n        extra_env = extra_env or {}\n        cmd = [sys.executable]\n        cmd += [\"-m\", \"mlc_llm\", \"serve\", self.model]\n        if self.model_lib is not None:\n            cmd += [\"--model-lib\", self.model_lib]\n        cmd += [\"--device\", self.device]\n\n        if self.enable_debug:\n            cmd += [\"--enable-debug\"]\n\n        if self.mode is not None:\n            cmd += [\"--mode\", self.mode]\n\n        if len(self.engine_config.additional_models) > 0:\n            args_additional_model = []\n            for additional_model in self.engine_config.additional_models:\n                if isinstance(additional_model, str):\n                    args_additional_model.append(additional_model)\n                else:\n                    args_additional_model.append(additional_model[0] + \",\" + additional_model[1])\n            cmd += [\"--additional-models\", *args_additional_model]\n        cmd += [\"--speculative-mode\", self.engine_config.speculative_mode]\n        cmd += [\"--prefix-cache-mode\", self.engine_config.prefix_cache_mode]\n\n        args_overrides = []\n        if self.engine_config.max_num_sequence is not None:\n            args_overrides.append(f\"max_num_sequence={self.engine_config.max_num_sequence}\")\n        if self.engine_config.max_total_sequence_length is not None:\n            args_overrides.append(\n                f\"max_total_seq_length={self.engine_config.max_total_sequence_length}\"\n            )\n        if self.engine_config.prefill_chunk_size is not None:\n            args_overrides.append(f\"prefill_chunk_size={self.engine_config.prefill_chunk_size}\")\n        if self.engine_config.max_history_size is not None:\n            args_overrides.append(f\"max_history_size={self.engine_config.max_history_size}\")\n        if self.engine_config.gpu_memory_utilization is not None:\n            
args_overrides.append(\n                f\"gpu_memory_utilization={self.engine_config.gpu_memory_utilization}\"\n            )\n        if self.engine_config.spec_draft_length is not None:\n            args_overrides.append(f\"spec_draft_length={self.engine_config.spec_draft_length}\")\n        if self.engine_config.prefix_cache_max_num_recycling_seqs is not None:\n            args_overrides.append(\n                \"prefix_cache_max_num_recycling_seqs=\"\n                + str(self.engine_config.prefix_cache_max_num_recycling_seqs)\n            )\n        if len(args_overrides) > 0:\n            cmd += [\"--overrides\", \";\".join(args_overrides)]\n\n        if self.enable_tracing:\n            cmd += [\"--enable-tracing\"]\n\n        cmd += [\"--host\", self.host]\n        cmd += [\"--port\", str(self.port)]\n        process_path = str(Path(__file__).resolve().parents[4])\n        final_env = os.environ.copy()\n        for key, value in extra_env.items():\n            final_env[key] = value\n        self._proc = subprocess.Popen(  # pylint: disable=consider-using-with\n            cmd, cwd=process_path, env=final_env\n        )\n        # NOTE: DO NOT USE `stdout=subprocess.PIPE, stderr=subprocess.PIPE`\n        # in subprocess.Popen here. 
PIPE has a fixed-size buffer which may block\n        # and hang forever.\n\n        # Try to query the server until it is ready.\n        self.base_url = f\"http://{self.host}:{str(self.port)}\"\n        self.openai_v1_base_url = f\"http://{self.host}:{str(self.port)}/v1\"\n        openai_v1_models_url = f\"{self.base_url}/v1/models\"\n\n        query_result = None\n        timeout = 120\n        attempts = 0.0\n        while query_result is None and attempts < timeout:\n            try:\n                query_result = requests.get(openai_v1_models_url, timeout=60)\n                if query_result.status_code != 200:\n                    query_result = None\n                    attempts += 0.1\n                    time.sleep(0.1)\n            except:  # pylint: disable=bare-except\n                attempts += 0.1\n                time.sleep(0.1)\n\n        # Check if the subprocess terminates unexpectedly or\n        # the queries reach the timeout.\n        process_return_code = self._proc.poll()\n        if process_return_code is not None:\n            raise RuntimeError(\n                \"The server fails to launch. 
\"\n                f'Please check if \"{self.model}\" is a valid model compiled by MLC LLM.'\n            )\n        # `attempts` accumulates in 0.1 steps as a float, so compare with >= rather than ==.\n        if attempts >= timeout:\n            self.terminate()\n            raise RuntimeError(f\"The server failed to launch in {timeout} seconds.\")\n\n    def terminate(self) -> None:\n        \"\"\"Terminate the server subprocess.\"\"\"\n        if self._proc is None:\n            return\n\n        # Kill all the child processes.\n        def kill_child_processes():\n            try:\n                parent = psutil.Process(self._proc.pid)\n                children = parent.children(recursive=True)\n            except psutil.NoSuchProcess:\n                return\n\n            for process in children:\n                try:\n                    process.kill()\n                except psutil.NoSuchProcess:\n                    pass\n\n        kill_child_processes()\n\n        # Kill the process.\n        try:\n            self._proc.kill()\n        except OSError:\n            pass\n\n        # Join the process to avoid zombies.\n        try:\n            self._proc.wait(timeout=10.0)\n        except subprocess.TimeoutExpired:\n            pass\n        self._proc = None\n\n    def __enter__(self):\n        \"\"\"Start the server.\"\"\"\n        self.start()\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        \"\"\"Terminate the server.\"\"\"\n        self.terminate()\n"
  },
  {
    "path": "python/mlc_llm/serve/server/server_context.py",
    "content": "\"\"\"Server context that is shared by multiple entrypoint files.\"\"\"\n\nfrom typing import TYPE_CHECKING, Dict, List, Optional\n\nfrom ..engine import AsyncMLCEngine\n\nif TYPE_CHECKING:\n    from ..embedding_engine import AsyncEmbeddingEngine\n\n\nclass ServerContext:\n    \"\"\"The global server context, including the running models\n    and corresponding async engines.\n    \"\"\"\n\n    server_context: Optional[\"ServerContext\"] = None\n    enable_debug: bool = False\n\n    def __init__(self) -> None:\n        self._models: Dict[str, AsyncMLCEngine] = {}\n        self._embedding_engines: Dict[str, \"AsyncEmbeddingEngine\"] = {}\n        self.api_key: Optional[str] = None\n\n    def __enter__(self):\n        if ServerContext.server_context is not None:\n            raise RuntimeError(\"Server context already exists.\")\n        ServerContext.server_context = self\n        return self\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        for model_engine in self._models.values():\n            model_engine.terminate()\n        for emb_engine in self._embedding_engines.values():\n            emb_engine.terminate()\n        self._models.clear()\n        self._embedding_engines.clear()\n        ServerContext.server_context = None\n\n    @staticmethod\n    def current():\n        \"\"\"Returns the current ServerContext.\"\"\"\n        return ServerContext.server_context\n\n    def add_model(self, hosted_model: str, engine: AsyncMLCEngine) -> None:\n        \"\"\"Add a new model to the server context together with the engine.\"\"\"\n        if hosted_model in self._models:\n            raise RuntimeError(f\"Model {hosted_model} already running.\")\n        self._models[hosted_model] = engine\n\n    def get_engine(self, model: Optional[str]) -> Optional[AsyncMLCEngine]:\n        \"\"\"Get the async engine of the requested model, or the unique async engine\n        if only one engine is served.\"\"\"\n        if len(self._models) == 1:\n    
        return next(iter(self._models.values()))\n        return self._models.get(model, None)\n\n    def get_model_list(self) -> List[str]:\n        \"\"\"Get the list of all models being served, including embedding models.\"\"\"\n        return list(self._models.keys()) + list(self._embedding_engines.keys())\n\n    def add_embedding_engine(self, hosted_model: str, engine: \"AsyncEmbeddingEngine\") -> None:\n        \"\"\"Add a new embedding model to the server context.\"\"\"\n        if hosted_model in self._embedding_engines:\n            raise RuntimeError(f\"Embedding model {hosted_model} already running.\")\n        self._embedding_engines[hosted_model] = engine\n\n    def get_embedding_engine(self, model: Optional[str]) -> Optional[\"AsyncEmbeddingEngine\"]:\n        \"\"\"Get the embedding engine of the requested model, or the unique\n        embedding engine if only one is served.\"\"\"\n        if len(self._embedding_engines) == 1:\n            return next(iter(self._embedding_engines.values()))\n        return self._embedding_engines.get(model, None)\n"
  },
  {
    "path": "python/mlc_llm/serve/sync_engine.py",
    "content": "\"\"\"The MLC LLM synchronous engine.\n\nNOTE: The engine defined in this file directly wraps the underlying\nEngine implementation in C++, is not optimized by multi-threading and\ndoes not offer the standard OpenAI API interface.\n\nWe neither expose it nor use it by default. As of now it mainly serves\ntesting and debugging purposes because of its simplicity.\n\"\"\"\n\nimport json\nfrom typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union\n\nimport tvm\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import data\nfrom mlc_llm.serve.config import EngineConfig\nfrom mlc_llm.serve.engine_base import (\n    EngineMetrics,\n    _check_engine_config,\n    _parse_models,\n    _print_engine_mode_logging_msg,\n    _process_model_args,\n    detect_device,\n)\nfrom mlc_llm.serve.event_trace_recorder import EventTraceRecorder\nfrom mlc_llm.serve.request import Request\nfrom mlc_llm.support import logging\nfrom mlc_llm.tokenizers import TextStreamer, Tokenizer\n\nlogger = logging.getLogger(__name__)\n\n\ndef _create_tvm_module(\n    creator: str, ffi_funcs: Sequence[str], creator_args: Optional[List[Any]] = None\n) -> Dict[str, Callable]:\n    \"\"\"Internal method to create a module.\"\"\"\n    if creator_args is None:\n        creator_args = []\n    module = tvm.get_global_func(creator, allow_missing=False)(*creator_args)\n    return {key: module[key] for key in ffi_funcs}\n\n\nclass SyncMLCEngine:\n    \"\"\"The Python interface of the synchronous request serving engine for MLC LLM.\n\n    The engine receives requests from the \"add_request\" method. For\n    a given request, the engine keeps generating new tokens for\n    the request until it finishes (under certain criteria). 
When the request finishes,\n    the engine returns the generation result through the request\n    stream callback function.\n\n    NOTE: This engine directly wraps the underlying Engine implementation\n    in C++, is not optimized by multi-threading and does not offer the standard\n    OpenAI API interface. We neither expose it nor use it by default.\n    As of now it mainly serves testing and debugging purposes because of its\n    simplicity.\n\n    Parameters\n    ----------\n    engine_config : Optional[EngineConfig]\n        Additional configurable arguments of the MLC engine.\n        See class \"EngineConfig\" for more detail.\n\n    enable_tracing : bool\n        A boolean indicating whether to enable event logging for requests.\n\n    request_stream_callback : Optional[Callable[[List[data.RequestStreamOutput]], None]]\n        The provided callback function to handle the generation\n        output. It takes a list of `data.RequestStreamOutput`. Each output\n        unpacks into\n        - the request id, and\n        - a list of stream outputs, one per parallel generation. Each stream\n        output carries the generated **delta** token ids since the last\n        invocation of the callback on the specific request, and an optional\n        finish reason that is None while the generation is unfinished.\n\n        The callback function is optional at construction, but it needs to\n        be set before the engine executes requests. This can be done via\n        the `set_request_stream_callback` method. 
Otherwise, the engine will raise an\n        exception.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments,too-many-locals\n        self,\n        model: str,\n        device: Union[str, tvm.runtime.Device] = \"auto\",\n        *,\n        model_lib: Optional[str] = None,\n        mode: Literal[\"local\", \"interactive\", \"server\"] = \"local\",\n        engine_config: Optional[EngineConfig] = None,\n        enable_tracing: bool = False,\n        request_stream_callback: Optional[Callable[[List[data.RequestStreamOutput]], None]] = None,\n    ):\n        # - Check the fields of `engine_config`.\n        if engine_config is None:\n            engine_config = EngineConfig()\n        _check_engine_config(\n            model,\n            model_lib,\n            mode,\n            engine_config,\n        )\n\n        # - Initialize model loading info.\n        models = _parse_models(model, model_lib, engine_config.additional_models)\n        if isinstance(device, str):\n            device = detect_device(device)\n        assert isinstance(device, tvm.runtime.Device)\n        (\n            model_args,\n            model_config_paths,\n            self.conv_template,\n        ) = _process_model_args(models, device, engine_config)\n\n        # - Load the raw model configs into dicts.\n        self.model_config_dicts = []\n        for i, model_info in enumerate(models):\n            model_info.model_lib = model_args[i][1]\n            with open(model_config_paths[i], \"r\", encoding=\"utf-8\") as file:\n                self.model_config_dicts.append(json.load(file))\n\n        # - Print logging info regarding the mode selection.\n        if engine_config.verbose:\n            _print_engine_mode_logging_msg(mode)\n\n        self._ffi = _create_tvm_module(\n            \"mlc.serve.create_engine\",\n            ffi_funcs=[\n                \"init\",\n                \"add_request\",\n                \"abort_request\",\n                \"step\",\n     
           \"reset\",\n                \"json_metrics\",\n                \"get_request_stream_callback\",\n                \"set_request_stream_callback\",\n                \"create_request\",\n            ],\n        )\n        self.trace_recorder = EventTraceRecorder() if enable_tracing else None\n\n        engine_config.model = model_args[0][0]\n        engine_config.model_lib = model_args[0][1]\n        engine_config.additional_models = model_args[1:]  # type: ignore\n        engine_config.mode = mode\n        self._ffi[\"init\"](\n            engine_config.asjson(),\n            device,\n            request_stream_callback,\n            self.trace_recorder,\n        )\n        self.tokenizer = Tokenizer(model_args[0][0])\n\n    def generate(  # pylint: disable=too-many-locals\n        self,\n        prompts: Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]],\n        generation_config: Union[GenerationConfig, List[GenerationConfig]],\n    ) -> Tuple[List[List[str]], List[Optional[List[List[str]]]]]:\n        \"\"\"Generate texts for a list of input prompts.\n        Each prompt can be a string, a list of token ids, or a list of `data.Data`.\n        The generation for each prompt is independent.\n        Return the generation results, one for each prompt.\n\n        Parameters\n        ----------\n        prompts : Union[str, List[str], List[int], List[List[int]], List[List[data.Data]]]\n            One or a list of input prompts for text generation.\n            Each prompt can be a string, a list of token ids, or a list of `data.Data`.\n\n        generation_config : Union[GenerationConfig, List[GenerationConfig]]\n            The generation config for each request.\n            If it is a single GenerationConfig instance,\n            this config will be shared by all the prompts.\n            Otherwise, one generation config is required for every\n            prompt.\n\n        Returns\n        -------\n        output_texts : List[List[str]]\n            The text generation results, one list of 
strings for each input prompt.\n            The length of each list is the parallel generation `n` in\n            generation config.\n\n        output_logprobs_str : List[Optional[List[List[str]]]]\n            The logprob strings of each token for each input prompt, or None\n            if an input prompt does not require logprobs.\n        \"\"\"\n        if isinstance(prompts, str):\n            # `prompts` is a single string.\n            prompts = [prompts]\n        else:\n            assert isinstance(prompts, list), (\n                \"Input `prompts` is expected to be a string, a list of \"\n                \"str, a list of token ids or multiple lists of token ids. \"\n            )\n            if len(prompts) == 0:\n                return [], []\n            if isinstance(prompts[0], int):\n                # `prompts` is a list of token ids\n                prompts = [prompts]  # type: ignore\n\n        num_requests = len(prompts)\n        if not isinstance(generation_config, list):\n            generation_config = [generation_config] * num_requests\n\n        assert (\n            len(generation_config) == num_requests\n        ), \"Number of generation config and number of prompts mismatch\"\n\n        num_finished_generations = 0\n        output_texts: List[List[str]] = []\n        output_logprobs_str: List[Optional[List[List[str]]]] = []\n        text_streamers: List[List[TextStreamer]] = []\n        for i in range(num_requests):\n            output_texts.append([])\n            output_logprobs_str.append([] if generation_config[i].logprobs else None)\n            text_streamers.append([])\n            for _ in range(generation_config[i].n):\n                output_texts[i].append(\"\")\n                text_streamers[i].append(TextStreamer(self.tokenizer))\n                if output_logprobs_str[i] is not None:\n                    output_logprobs_str[i].append([])\n\n        num_total_generations = sum(cfg.n for cfg in generation_config)\n\n       
 # Save a copy of the original function callback since `generate`\n        # overrides the callback function.\n        # The original callback will be set back later on.\n        original_callback = self._ffi[\"get_request_stream_callback\"]()\n\n        # Define the callback function for request generation results\n        def request_stream_callback(delta_outputs: List[data.RequestStreamOutput]):\n            nonlocal num_finished_generations\n            for delta_output in delta_outputs:\n                request_id, stream_outputs = delta_output.unpack()\n                rid = int(request_id)\n\n                assert len(stream_outputs) == generation_config[rid].n  # type: ignore\n                for i, (stream_output, text_streamer) in enumerate(\n                    zip(stream_outputs, text_streamers[rid])\n                ):\n                    if output_logprobs_str[rid] is not None:\n                        assert stream_output.delta_logprob_json_strs is not None\n                        output_logprobs_str[rid][i] += stream_output.delta_logprob_json_strs\n\n                    delta_text = stream_output.extra_prefix_string + (\n                        text_streamer.put(stream_output.delta_token_ids)\n                        if len(stream_output.delta_token_ids) > 0\n                        else \"\"\n                    )\n                    if stream_output.finish_reason is not None:\n                        delta_text += text_streamer.finish()\n\n                    output_texts[rid][i] += delta_text\n                    if stream_output.finish_reason is not None:\n                        num_finished_generations += 1\n\n        # Override the callback function in engine.\n        self._ffi[\"set_request_stream_callback\"](request_stream_callback)\n\n        def convert_to_data(\n            prompt: Union[str, List[int], List[data.Data]],\n        ) -> List[data.Data]:\n            if isinstance(prompt, str):\n                return 
[data.TextData(prompt)]\n            if isinstance(prompt[0], int):\n                return [data.TokenData(prompt)]  # type: ignore\n            return prompt  # type: ignore\n\n        # Add requests to engine.\n        for req_id, (prompt, generation_cfg) in enumerate(zip(prompts, generation_config)):\n            input_data = convert_to_data(prompt)  # type: ignore\n            self.add_request(\n                self.create_request(\n                    request_id=str(req_id),\n                    inputs=input_data,\n                    generation_config=generation_cfg,\n                )\n            )\n\n        while num_finished_generations != num_total_generations:\n            self.step()\n\n        # Restore the callback function in engine.\n        self._ffi[\"set_request_stream_callback\"](original_callback)\n        return output_texts, output_logprobs_str\n\n    def create_request(\n        self,\n        request_id: str,\n        inputs: Union[data.Data, List[data.Data]],\n        generation_config: GenerationConfig,\n    ):\n        \"\"\"Create a new request that can be added to engine.\n\n        Parameters\n        ----------\n        request_id : str\n            The unique identifier of the request.\n            Different requests should have different ids.\n\n        inputs : List[Data]\n            The user inputs of a request. 
They may be multi-modal.\n\n        generation_config : GenerationConfig\n            The generation configuration of the request.\n\n        Note\n        ----\n        The engine may fill in the default generation config of the model.\n        \"\"\"\n        if not isinstance(inputs, list):\n            inputs = [inputs]\n        return self._ffi[\"create_request\"](\n            request_id, inputs, generation_config.model_dump_json(by_alias=True)\n        )\n\n    def add_request(self, request: Request) -> None:\n        \"\"\"Add a new request to the engine.\n\n        Parameters\n        ----------\n        request : Request\n            The request to add.\n        \"\"\"\n        self._ffi[\"add_request\"](request)\n\n    def abort_request(self, request_id: str) -> None:\n        \"\"\"Abort the generation of the request corresponding to the input request id.\n\n        Parameters\n        ----------\n        request_id : str\n            The unique id of the request to abort.\n        \"\"\"\n        self._ffi[\"abort_request\"](request_id)\n\n    def step(self) -> None:\n        \"\"\"Take one step of action in the engine's main loop.\n\n        At each step, the engine may decide to\n        - run prefill for one (or more) requests,\n        - run one-step decode for all existing requests\n        ...\n\n        At the end of certain actions (e.g., decode), the engine will\n        check if any request has finished, and will return the\n        generation results for those finished requests.\n        \"\"\"\n        self._ffi[\"step\"]()\n\n    def reset(self) -> None:\n        \"\"\"Reset the engine, clean up all running data and metrics.\"\"\"\n        self._ffi[\"reset\"]()\n\n    def metrics(self) -> EngineMetrics:\n        \"\"\"Return the current metrics of the engine.\"\"\"\n        return EngineMetrics(json.loads(self._ffi[\"json_metrics\"]()))\n"
  },
  {
    "path": "python/mlc_llm/support/__init__.py",
    "content": "\"\"\"\nCommon utilities used in the Python package. Do not import anything by default,\nas they may introduce unnecessary dependencies.\n\"\"\"\n"
  },
  {
    "path": "python/mlc_llm/support/argparse.py",
    "content": "\"\"\"An enhanced argument parser for mlc-chat.\"\"\"\n\nimport argparse\nimport sys\n\n\nclass ArgumentParser(argparse.ArgumentParser):\n    \"\"\"An enhanced argument parser for mlc-chat.\"\"\"\n\n    def error(self, message):\n        \"\"\"Overrides the behavior when erroring out\"\"\"\n        print(\"-\" * 25 + \" Usage \" + \"-\" * 25)\n        self.print_help()\n        print(\"-\" * 25 + \" Error \" + \"-\" * 25)\n        print(message, file=sys.stderr)\n        sys.exit(2)\n"
  },
  {
    "path": "python/mlc_llm/support/auto_config.py",
    "content": "\"\"\"Helper functions for detecting the model configuration file `config.json`\"\"\"\n\nimport json\nimport tempfile\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING\n\nfrom . import logging\nfrom .style import bold, green\n\nif TYPE_CHECKING:\n    from mlc_llm.model import Model  # pylint: disable=unused-import\n    from mlc_llm.quantization import Quantization  # pylint: disable=unused-import\n\n\nlogger = logging.getLogger(__name__)\n\nFOUND = green(\"Found\")\n\n\ndef detect_mlc_chat_config(mlc_chat_config: str) -> Path:\n    \"\"\"Detect and return the path that points to mlc-chat-config.json.\n    If `mlc_chat_config` is a directory, it looks for mlc-chat-config.json below it.\n\n    Parameters\n    ----------\n    mlc_chat_config : str\n        The path to `mlc-chat-config.json`, or the directory containing\n        `mlc-chat-config.json`.\n\n    Returns\n    -------\n    mlc_chat_config_json_path : pathlib.Path\n        The path that points to mlc-chat-config.json.\n    \"\"\"\n    # pylint: disable=import-outside-toplevel\n    from mlc_llm.model import MODEL_PRESETS\n\n    from .download_cache import download_and_cache_mlc_weights\n\n    # pylint: enable=import-outside-toplevel\n\n    if mlc_chat_config.startswith(\"HF://\") or mlc_chat_config.startswith(\"http\"):\n        mlc_chat_config_path = Path(download_and_cache_mlc_weights(model_url=mlc_chat_config))\n    elif isinstance(mlc_chat_config, str) and mlc_chat_config in MODEL_PRESETS:\n        logger.info(\"%s mlc preset model: %s\", FOUND, mlc_chat_config)\n        content = MODEL_PRESETS[mlc_chat_config].copy()\n        content[\"model_preset_tag\"] = mlc_chat_config\n        temp_file = tempfile.NamedTemporaryFile(  # pylint: disable=consider-using-with\n            suffix=\".json\",\n            delete=False,\n        )\n        logger.info(\"Dumping config to: %s\", temp_file.name)\n        mlc_chat_config_path = Path(temp_file.name)\n        with 
mlc_chat_config_path.open(\"w\", encoding=\"utf-8\") as mlc_chat_config_file:\n            json.dump(content, mlc_chat_config_file, indent=2)\n    else:\n        mlc_chat_config_path = Path(mlc_chat_config)\n    if not mlc_chat_config_path.exists():\n        raise ValueError(f\"{mlc_chat_config_path} does not exist.\")\n\n    if mlc_chat_config_path.is_dir():\n        # search mlc-chat-config.json under path\n        mlc_chat_config_json_path = mlc_chat_config_path / \"mlc-chat-config.json\"\n        if not mlc_chat_config_json_path.exists():\n            raise ValueError(f\"Failed to find mlc-chat-config.json under {mlc_chat_config_path}.\")\n    else:\n        mlc_chat_config_json_path = mlc_chat_config_path\n\n    logger.info(\"%s model configuration: %s\", FOUND, mlc_chat_config_json_path)\n    return mlc_chat_config_json_path\n\n\ndef detect_config(config: str) -> Path:\n    \"\"\"Detect and return the path that points to config.json. If `config` is a directory,\n    it looks for config.json below it.\n\n    Parameters\n    ----------\n    config : str\n        The preset name of the model, or the path to `config.json`, or the directory containing\n        `config.json`.\n\n    Returns\n    -------\n    config_json_path : pathlib.Path\n        The path that points to config.json.\n    \"\"\"\n    from mlc_llm.model import MODEL_PRESETS  # pylint: disable=import-outside-toplevel\n\n    if isinstance(config, str) and config in MODEL_PRESETS:\n        logger.info(\"%s preset model: %s\", FOUND, config)\n        content = MODEL_PRESETS[config].copy()\n        content[\"model_preset_tag\"] = config\n        temp_file = tempfile.NamedTemporaryFile(  # pylint: disable=consider-using-with\n            suffix=\".json\",\n            delete=False,\n        )\n        logger.info(\"Dumping config to: %s\", temp_file.name)\n        config_path = Path(temp_file.name)\n        with config_path.open(\"w\", encoding=\"utf-8\") as config_file:\n            json.dump(content, 
config_file, indent=2)\n    else:\n        config_path = Path(config)\n    if not config_path.exists():\n        raise ValueError(f\"{config_path} does not exist.\")\n\n    if config_path.is_dir():\n        # search config.json under config path\n        config_json_path = config_path / \"config.json\"\n        if not config_json_path.exists():\n            raise ValueError(f\"Failed to find config.json under {config_path}.\")\n    else:\n        config_json_path = config_path\n\n    logger.info(\"%s model configuration: %s\", FOUND, config_json_path)\n    return config_json_path\n\n\ndef detect_model_type(model_type: str, config: Path) -> \"Model\":\n    \"\"\"Detect the model type from the configuration file. If `model_type` is \"auto\", it will be\n    inferred from the configuration file. Otherwise, it will be used as the model type, and a sanity\n    check will be performed.\n\n    Parameters\n    ----------\n    model_type : str\n        The model type, for example, \"llama\".\n\n    config : pathlib.Path\n        The path to config.json.\n\n    Returns\n    -------\n    model : mlc_llm.model.Model\n        The model type.\n    \"\"\"\n\n    from mlc_llm.model import MODELS  # pylint: disable=import-outside-toplevel\n\n    if model_type == \"auto\":\n        with open(config, \"r\", encoding=\"utf-8\") as config_file:\n            cfg = json.load(config_file)\n        if \"model_type\" not in cfg and (\n            \"model_config\" not in cfg or \"model_type\" not in cfg[\"model_config\"]\n        ):\n            raise ValueError(\n                f\"'model_type' not found in: {config}. \"\n                f\"Please explicitly specify `--model-type` instead.\"\n            )\n        model_type = cfg[\"model_type\"] if \"model_type\" in cfg else cfg[\"model_config\"][\"model_type\"]\n    if model_type in [\"mixformer-sequential\"]:\n        model_type = \"phi-msft\"\n    logger.info(\"%s model type: %s. 
Use `--model-type` to override.\", FOUND, bold(model_type))\n    if model_type not in MODELS:\n        raise ValueError(f\"Unknown model type: {model_type}. Available ones: {list(MODELS.keys())}\")\n    return MODELS[model_type]\n\n\ndef detect_quantization(quantization_arg: str, config: Path) -> \"Quantization\":\n    \"\"\"Detect the model quantization scheme from the configuration file or `--quantization`\n    argument. If `--quantization` is provided, it will override the value on the configuration\n    file.\n\n    Parameters\n    ----------\n    quantization_arg : str\n        The quantization scheme, for example, \"q4f16_1\".\n\n    config : pathlib.Path\n        The path to mlc-chat-config.json.\n\n    Returns\n    -------\n    quantization : mlc_llm.quantization.Quantization\n        The model quantization scheme.\n    \"\"\"\n    from mlc_llm.quantization import (  # pylint: disable=import-outside-toplevel\n        QUANTIZATION,\n    )\n\n    with open(config, \"r\", encoding=\"utf-8\") as config_file:\n        cfg = json.load(config_file)\n    if quantization_arg is not None:\n        quantization = QUANTIZATION[quantization_arg]\n    elif \"quantization\" in cfg:\n        quantization = QUANTIZATION[cfg[\"quantization\"]]\n    else:\n        raise ValueError(\n            f\"'quantization' not found in: {config}. \"\n            f\"Please explicitly specify `--quantization` instead.\"\n        )\n    return quantization\n"
  },
  {
    "path": "python/mlc_llm/support/auto_device.py",
    "content": "\"\"\"Automatic detection of the device available on the local machine.\"\"\"\n\nimport os\nimport subprocess\nimport sys\nfrom typing import Dict, Optional\n\nimport tvm\nfrom tvm.runtime import Device\nfrom tvm_ffi import DLDeviceType\n\nfrom . import logging\nfrom .style import bold, green, red\n\nFOUND = green(\"Found\")\nNOT_FOUND = red(\"Not found\")\nAUTO_DETECT_DEVICES = [\"cuda\", \"rocm\", \"metal\", \"vulkan\", \"opencl\", \"cpu\"]\n_RESULT_CACHE: Dict[str, bool] = {}\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef detect_device(device_hint: str) -> Optional[Device]:\n    \"\"\"Detect locally available device from string hint.\"\"\"\n    if device_hint == \"auto\":\n        device = None\n        for device_type in AUTO_DETECT_DEVICES:\n            cur_device = tvm.device(device_type=device_type, index=0)\n            if _device_exists(cur_device):\n                if device is None:\n                    device = cur_device\n        if device is None:\n            logger.info(\"%s: No available device detected\", NOT_FOUND)\n            return None\n        logger.info(\"Using device: %s\", bold(device2str(device)))\n        return device\n    try:\n        device = tvm.device(device_hint)\n    except Exception as err:\n        raise ValueError(f\"Invalid device name: {device_hint}\") from err\n    if not _device_exists(device):\n        raise ValueError(f\"Device is not found on your local environment: {device_hint}\")\n    return device\n\n\ndef device2str(device: Device) -> str:\n    \"\"\"Convert a TVM device object to string.\"\"\"\n    return f\"{tvm.runtime.Device._DEVICE_TYPE_TO_NAME[device.dlpack_device_type()]}:{device.index}\"  # pylint: disable=protected-access, line-too-long\n\n\ndef _device_exists(device: Device) -> bool:\n    device_type = tvm.runtime.Device._DEVICE_TYPE_TO_NAME[  # pylint: disable=protected-access\n        device.dlpack_device_type()\n    ]\n    device_str = device2str(device)\n    if device_str in 
_RESULT_CACHE:\n        return _RESULT_CACHE[device_str]\n    cmd = [\n        sys.executable,\n        \"-m\",\n        \"mlc_llm.cli.check_device\",\n        device_type,\n    ]\n    prefix = \"check_device:\"\n    subproc_outputs = [\n        line[len(prefix) :].strip()\n        for line in subprocess.run(\n            cmd,\n            capture_output=True,\n            text=True,\n            check=False,\n            env=os.environ,\n        )\n        .stdout.strip()\n        .splitlines()\n        if line.startswith(prefix)\n    ]\n    if subproc_outputs:\n        if subproc_outputs[0]:\n            for i in subproc_outputs[0].split(\",\"):\n                logger.info(\"%s device: %s:%s\", FOUND, device_type, i)\n                _RESULT_CACHE[f\"{device_type}:{i}\"] = True\n                if device.dlpack_device_type() == DLDeviceType.kDLCPU:\n                    break\n    else:\n        logger.error(\n            \"GPU device detection failed. Please report this issue with the output of command: %s\",\n            \" \".join(cmd),\n        )\n    if device_str in _RESULT_CACHE:\n        return _RESULT_CACHE[device_str]\n    logger.info(\"%s device: %s\", NOT_FOUND, device_str)\n    _RESULT_CACHE[device_str] = False\n    return False\n"
  },
  {
    "path": "python/mlc_llm/support/auto_target.py",
    "content": "\"\"\"Helper functions for target auto-detection.\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import TYPE_CHECKING, Callable, List, Optional, Tuple\n\nfrom tvm import IRModule, relax\nfrom tvm.contrib import ndk, tar, xcode\nfrom tvm.ir.transform import Pass\nfrom tvm.target import Target\nfrom tvm_ffi import get_global_func, register_global_func\n\nfrom . import logging\nfrom .auto_device import AUTO_DETECT_DEVICES, detect_device, device2str\nfrom .constants import MLC_MULTI_ARCH\nfrom .style import bold, green, red\n\nif TYPE_CHECKING:\n    from mlc_llm.compiler.compile import CompileArgs\n\n\nlogger = logging.getLogger(__name__)\n\n# TODO: add help message on how to specify the target manually # pylint: disable=fixme\nHELP_MSG = \"\"\"TBD\"\"\"\nFOUND = green(\"Found\")\nNOT_FOUND = red(\"Not found\")\nBuildFunc = Callable[[IRModule, \"CompileArgs\", Pass], None]\n\n\ndef detect_target_and_host(target_hint: str, host_hint: str = \"auto\") -> Tuple[Target, BuildFunc]:\n    \"\"\"Detect the configuration for the target device and its host, for example, target GPU and\n    the host CPU.\n\n    Parameters\n    ----------\n    target_hint : str\n        The hint for the target device.\n\n    host_hint : str\n        The hint for the host CPU, default is \"auto\".\n    \"\"\"\n    target, build_func = _detect_target_gpu(target_hint)\n    if target.host is None:\n        target = Target(target, host=_detect_target_host(host_hint))\n    if target.kind.name == \"cuda\":\n        # Enable thrust for CUDA\n        target_dict = dict(target.export())\n        target_dict[\"libs\"] = (\n            (target_dict[\"libs\"] + [\"thrust\"]) if \"libs\" in target_dict else [\"thrust\"]\n        )\n        target = Target(target_dict)\n        _register_cuda_hook(target)\n    elif target.kind.name == \"rocm\":\n        target_dict = dict(target.export())\n        extra_libs = [\"thrust\", \"rocblas\", \"miopen\", \"hipblas\"]\n        
target_dict[\"libs\"] = (\n            (target_dict[\"libs\"] + extra_libs) if \"libs\" in target_dict else extra_libs\n        )\n        target = Target(target_dict)\n    return target, build_func\n\n\ndef _detect_target_gpu(hint: str) -> Tuple[Target, BuildFunc]:\n    if hint in [\"iphone\", \"macabi\", \"android\", \"webgpu\", \"mali\", \"opencl\"]:\n        hint += \":generic\"\n    if hint == \"auto\" or hint in AUTO_DETECT_DEVICES:\n        target: Optional[Target] = None\n        device = detect_device(hint)\n        if device is not None:\n            device_str = device2str(device)\n            try:\n                target = Target.from_device(device)\n            except ValueError:\n                logger.info(\"%s: Cannot detect target from device: %s\", NOT_FOUND, device_str)\n        if target is None:\n            raise ValueError(f\"No target detected from device: {hint}. Please specify explicitly\")\n        logger.info(\n            '%s configuration of target device \"%s\": %s',\n            FOUND,\n            bold(device_str),\n            target.export(),\n        )\n        return target, _build_default()\n    if hint in PRESET:\n        preset = PRESET[hint]\n        target = Target(preset[\"target\"])  # type: ignore[index]\n        build = preset.get(\"build\", _build_default)  # type: ignore[attr-defined]\n        return target, build()\n    if _is_device(hint):\n        logger.info(\"Detecting target device: %s\", hint)\n        target = Target.from_device(hint)\n        logger.info(\"%s target: %s\", FOUND, target.export())\n        return target, _build_default()\n    try:\n        logger.info(\"Try creating device target from string: %s\", hint)\n        target = Target(hint)\n        logger.info(\"%s target: %s\", FOUND, target.export())\n        return target, _build_default()\n    except Exception as err:\n        logger.info(\"%s: Failed to create target\", NOT_FOUND)\n        raise ValueError(f\"Invalid target: {hint}\") from 
err\n\n\ndef _detect_target_host(hint: str) -> Target:\n    \"\"\"Detect the host CPU architecture.\"\"\"\n    if hint == \"auto\":\n        target_triple = get_global_func(\"tvm.codegen.llvm.GetDefaultTargetTriple\")()\n        target = Target.from_device(\"cpu\")\n        logger.info(\"%s host LLVM triple: %s\", FOUND, bold(target.attrs[\"mtriple\"]))\n        logger.info(\"%s host LLVM CPU: %s\", FOUND, bold(target.attrs[\"mcpu\"]))\n        return target\n    target_triple = hint\n    logger.info(\"Using LLVM triple specified by --host: %s\", bold(target_triple))\n    return Target({\"kind\": \"llvm\", \"mtriple\": target_triple})\n\n\ndef _is_device(device: str):\n    if \" \" in device:\n        return False\n    if device.count(\":\") != 1:\n        return False\n    return True\n\n\ndef _add_system_lib_prefix(mod: IRModule, prefix: str, is_system_lib: bool) -> IRModule:\n    if is_system_lib and prefix:\n        mod = mod.with_attrs({\"system_lib_prefix\": prefix})  # type: ignore[dict-item]\n    elif is_system_lib:\n        logger.warning(\n            \"%s is not specified when building a static library\",\n            bold(\"--system-lib-prefix\"),\n        )\n    elif prefix:\n        logger.warning(\n            \"--system-lib-prefix is specified, but it will not take any effect \"\n            \"when building the shared library\"\n        )\n    return mod\n\n\ndef _build_metal_x86_64():\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=False)\n        assert output.suffix == \".dylib\"\n        relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n        ).export_library(\n            str(output),\n            fcompile=xcode.create_dylib,\n            sdk=\"macosx\",\n            arch=\"x86_64\",\n        )\n\n    return build\n\n\ndef _build_iphone():\n    
@register_global_func(\"tvm_callback_metal_compile\", override=True)\n    def compile_metal(src, target):\n        libs = target.attrs.get(\"libs\", None)\n        if libs:\n            return xcode.compile_metal(src, sdk=libs[0])\n        return xcode.compile_metal(src)\n\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True)\n        assert output.suffix == \".tar\"\n        relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n            system_lib=True,\n        ).export_library(\n            str(output),\n            fcompile=tar.tar,\n        )\n\n    return build\n\n\ndef _build_android():\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True)\n        assert output.suffix == \".tar\"\n        ex = relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n            system_lib=True,\n        )\n        ex.export_library(\n            str(output),\n            fcompile=tar.tar,\n        )\n        if args.debug_dump is not None:\n            source = ex.mod.imports[0].imports[0].inspect_source()\n            with open(args.debug_dump / \"kernel.cl\", \"w\", encoding=\"utf-8\") as f:\n                f.write(source)\n\n    return build\n\n\ndef _build_android_so():\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=False)\n        assert output.suffix == \".so\"\n        ex = relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n            system_lib=False,\n        )\n        ex.export_library(\n            str(output),\n            
fcompile=ndk.create_shared,\n        )\n        if args.debug_dump is not None:\n            source = ex.mod.imports[0].imports[0].inspect_source()\n            with open(args.debug_dump / \"kernel.cl\", \"w\", encoding=\"utf-8\") as f:\n                f.write(source)\n\n    return build\n\n\ndef _build_webgpu():\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True)\n        assert output.suffix == \".wasm\"\n\n        # Try to locate `mlc_wasm_runtime.bc`\n        bc_path = None\n        bc_candidates = [\"web/dist/wasm/mlc_wasm_runtime.bc\"]\n        if os.environ.get(\"MLC_LLM_SOURCE_DIR\", None):\n            mlc_source_home_dir = os.environ[\"MLC_LLM_SOURCE_DIR\"]\n            bc_candidates.append(\n                os.path.join(mlc_source_home_dir, \"web\", \"dist\", \"wasm\", \"mlc_wasm_runtime.bc\")\n            )\n        error_info = (\n            \"Cannot find library: mlc_wasm_runtime.bc\\n\"\n            + \"Make sure you have run `./web/prep_emcc_deps.sh` and \"\n            + \"`export MLC_LLM_SOURCE_DIR=/path/to/mlc-llm` so that we can locate the file. 
\"\n            + \"We tried to look at candidate paths:\\n\"\n        )\n        for candidate in bc_candidates:\n            error_info += candidate + \"\\n\"\n            if Path(candidate).exists():\n                bc_path = candidate\n        if not bc_path:\n            raise RuntimeError(error_info)\n\n        relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n            system_lib=True,\n        ).export_library(\n            str(output),\n            libs=[bc_path],\n        )\n\n    return build\n\n\ndef _build_mali():\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=True)\n        assert output.suffix == \".so\"\n        mod = relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n            system_lib=True,\n        )\n        if \"TVM_NDK_CC\" in os.environ:\n            mod.export_library(str(output), fcompile=ndk.create_shared)\n        else:\n            mod.export_library(str(output))\n\n    return build\n\n\ndef _build_default():\n    def build(mod: IRModule, args: \"CompileArgs\", pipeline=None):\n        output = args.output\n        if output.suffix in [\".tar\", \".lib\"]:\n            system_lib = True\n        elif output.suffix in [\".so\", \".dylib\", \".dll\"]:\n            system_lib = False\n        else:\n            logger.warning(\"Unknown output suffix: %s. 
Assuming shared library.\", output.suffix)\n            system_lib = False\n        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=system_lib)\n        relax.build(\n            mod,\n            target=args.target,\n            relax_pipeline=pipeline,\n            system_lib=system_lib,\n        ).export_library(\n            str(output),\n        )\n\n    return build\n\n\ndef detect_cuda_arch_list(target: Target) -> List[int]:\n    \"\"\"Detect the CUDA architecture list from the target.\"\"\"\n\n    def convert_to_num(arch_str):\n        arch_num_str = \"\".join(filter(str.isdigit, arch_str))\n        assert arch_num_str, f\"'{arch_str}' does not contain any digits\"\n        return int(arch_num_str)\n\n    assert target.kind.name == \"cuda\", f\"Expected target to be CUDA, but got {target}\"\n    if MLC_MULTI_ARCH is not None:\n        multi_arch = [convert_to_num(x) for x in MLC_MULTI_ARCH.split(\",\")]\n    else:\n        assert target.attrs.get(\"arch\", \"\").startswith(\"sm_\")\n        multi_arch = [convert_to_num(target.attrs.get(\"arch\")[3:])]\n    multi_arch = list(set(multi_arch))\n    return multi_arch\n\n\ndef _register_cuda_hook(target: Target):\n    if MLC_MULTI_ARCH is None:\n        default_arch = target.attrs.get(\"arch\", None)\n        logger.info(\"Generating code for CUDA architecture: %s\", bold(default_arch))\n        logger.info(\n            \"To produce multi-arch fatbin, set environment variable %s. 
\"\n            \"Example: MLC_MULTI_ARCH=70,72,75,80,86,87,89,90a\",\n            bold(\"MLC_MULTI_ARCH\"),\n        )\n        multi_arch = None\n    else:\n        logger.info(\"%s %s: %s\", FOUND, bold(\"MLC_MULTI_ARCH\"), MLC_MULTI_ARCH)\n        multi_arch = [x.strip() for x in MLC_MULTI_ARCH.split(\",\")]\n        logger.info(\"Generating code for CUDA architecture: %s\", multi_arch)\n\n    @register_global_func(\"tvm_callback_cuda_compile\", override=True)\n    def tvm_callback_cuda_compile(code, target):  # pylint: disable=unused-argument\n        \"\"\"Use nvcc to generate fatbin code for better optimization.\"\"\"\n        from tvm.contrib import nvcc  # pylint: disable=import-outside-toplevel\n\n        if multi_arch is None:\n            ptx = nvcc.compile_cuda(code, target_format=\"fatbin\")\n        else:\n            arch = []\n            for compute_version in multi_arch:\n                arch += [\n                    \"-gencode\",\n                    f\"arch=compute_{compute_version},code=sm_{compute_version}\",\n                ]\n            ptx = nvcc.compile_cuda(code, target_format=\"fatbin\", arch=arch)\n        return ptx\n\n\ndef detect_system_lib_prefix(\n    target_hint: str, prefix_hint: str, model_name: str, quantization: str\n) -> str:\n    \"\"\"Detect the iOS / Android system lib prefix to identify the library needed to load the app.\n\n    Parameters\n    ----------\n    target_hint : str\n        The hint for the target device.\n\n    prefix_hint : str\n        The hint for the system lib prefix.\n\n    model_name : str\n        The model name, used to build the prefix when `prefix_hint` is \"auto\".\n\n    quantization : str\n        The quantization name, used to build the prefix when `prefix_hint` is \"auto\".\n\n    Returns\n    -------\n    prefix : str\n        The system lib prefix, or an empty string when no prefix applies to the target.\n    \"\"\"\n    if prefix_hint == \"auto\" and (\n        target_hint.startswith(\"iphone\")\n        or 
target_hint.startswith(\"macabi\")\n        or target_hint.startswith(\"android\")\n    ):\n        prefix = f\"{model_name}_{quantization}_\".replace(\"-\", \"_\")\n        logger.warning(\n            \"%s is automatically picked from the filename, %s. This allows us to use the filename \"\n            \"as the model_lib in 
android/iOS builds. Please avoid renaming the .tar file when \"\n            \"uploading the prebuilt.\",\n            bold(\"--system-lib-prefix\"),\n            bold(prefix),\n        )\n        return prefix\n    if target_hint not in [\"iphone\", \"macabi\", \"android\"]:\n        return \"\"\n    return prefix_hint\n\n\n_MACABI_ARCH = os.environ.get(\"MLC_MACABI_ARCH\", \"\").strip() or \"arm64\"\nif _MACABI_ARCH not in [\"arm64\", \"x86_64\"]:\n    _MACABI_ARCH = \"arm64\"\n_MACABI_MTRIPLE = (\n    \"x86_64-apple-ios18.0-macabi\" if _MACABI_ARCH == \"x86_64\" else \"arm64-apple-ios18.0-macabi\"\n)\n\nPRESET = {\n    \"iphone:generic\": {\n        \"target\": {\n            \"kind\": \"metal\",\n            \"max_threads_per_block\": 256,\n            \"max_shared_memory_per_block\": 32768,\n            \"thread_warp_size\": 1,\n            \"libs\": [\"iphoneos\"],\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": \"arm64-apple-darwin\",\n            },\n        },\n        \"build\": _build_iphone,\n    },\n    \"macabi:generic\": {\n        \"target\": {\n            \"kind\": \"metal\",\n            \"max_threads_per_block\": 256,\n            \"max_shared_memory_per_block\": 32768,\n            \"thread_warp_size\": 1,\n            \"libs\": [\"macosx\"],\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": _MACABI_MTRIPLE,\n            },\n        },\n        \"build\": _build_iphone,\n    },\n    \"android:generic\": {\n        \"target\": {\n            \"kind\": \"opencl\",\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": \"aarch64-linux-android\",\n            },\n        },\n        \"build\": _build_android,\n    },\n    \"android:adreno\": {\n        \"target\": {\n            \"kind\": \"opencl\",\n            \"device\": \"adreno\",\n            \"max_threads_per_block\": 512,\n            \"host\": {\n                
\"kind\": \"llvm\",\n                \"mtriple\": \"aarch64-linux-android\",\n            },\n        },\n        \"build\": _build_android,\n    },\n    \"android:adreno-so\": {\n        \"target\": {\n            \"kind\": \"opencl\",\n            \"device\": \"adreno\",\n            \"max_threads_per_block\": 512,\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": \"aarch64-linux-android\",\n            },\n        },\n        \"build\": _build_android_so,\n    },\n    \"windows:adreno_x86\": {\n        \"target\": {\n            \"kind\": \"opencl\",\n            \"device\": \"adreno\",\n            \"max_threads_per_block\": 512,\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": \"x86_64-pc-windows-msvc\",\n            },\n        },\n    },\n    \"metal:x86-64\": {\n        \"target\": {\n            \"kind\": \"metal\",\n            \"max_threads_per_block\": 256,\n            \"max_shared_memory_per_block\": 32768,\n            \"thread_warp_size\": 1,\n        },\n        \"build\": _build_metal_x86_64,\n    },\n    \"webgpu:generic\": {\n        \"target\": {\n            \"kind\": \"webgpu\",\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": \"wasm32-unknown-unknown-wasm\",\n            },\n        },\n        \"build\": _build_webgpu,\n    },\n    \"opencl:generic\": {\n        \"target\": {\n            \"kind\": \"opencl\",\n        },\n    },\n    \"mali:generic\": {\n        \"target\": {\n            \"kind\": \"opencl\",\n            \"host\": {\n                \"kind\": \"llvm\",\n                \"mtriple\": \"aarch64-linux-gnu\",\n            },\n        },\n        \"build\": _build_mali,\n    },\n    \"metal:generic\": {\n        \"target\": {\n            \"kind\": \"metal\",\n            \"max_threads_per_block\": 256,\n            \"max_shared_memory_per_block\": 32768,\n            \"thread_warp_size\": 1,\n        
},\n    },\n    \"vulkan:generic\": {\n        \"target\": {\n            \"kind\": \"vulkan\",\n            \"max_threads_per_block\": 256,\n            \"max_shared_memory_per_block\": 32768,\n            \"thread_warp_size\": 1,\n            \"supports_float16\": 1,\n            \"supports_int64\": 1,\n            \"supports_int16\": 1,\n            \"supports_int8\": 1,\n            \"supports_8bit_buffer\": 1,\n            \"supports_16bit_buffer\": 1,\n            \"supports_storage_buffer_storage_class\": 1,\n        },\n    },\n}\n"
  },
  {
    "path": "python/mlc_llm/support/auto_weight.py",
    "content": "\"\"\"Helper functions for detecting weight paths and weight formats.\"\"\"\n\nimport json\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\n\nfrom . import logging\nfrom .style import bold, green, red\n\nlogger = logging.getLogger(__name__)\n\nFOUND = green(\"Found\")\nNOT_FOUND = red(\"Not found\")\n\n\ndef detect_weight(\n    weight_path: Path,\n    config_json_path: Path,\n    weight_format: str = \"auto\",\n) -> Tuple[Path, str]:\n    \"\"\"Detect the weight directory and the weight format.\n\n    Parameters\n    ----------\n    weight_path : pathlib.Path\n        The path to weight files. If `weight_path` is not None, check if it exists. Otherwise, find\n        `weight_path` in `config.json` or use the same directory as `config.json`.\n\n    config_json_path : pathlib.Path\n        The path to `config.json`.\n\n    weight_format : str\n        The hint for the weight format. If it is \"auto\", guess the weight format.\n        Otherwise, check that the weights are in that format.\n        Available weight formats:\n            - auto (guess the weight format)\n            - huggingface-torch (validate via checking pytorch_model.bin.index.json)\n            - huggingface-safetensor (validate via checking model.safetensors.index.json)\n            - awq\n            - ggml (not supported yet)\n            - gguf (not supported yet)\n\n    Returns\n    -------\n    weight_config_path : pathlib.Path\n        The path that points to the weights config file or the weights directory.\n\n    weight_format : str\n        The valid weight format.\n    \"\"\"\n    if weight_path is None:\n        assert (\n            config_json_path is not None and config_json_path.exists()\n        ), \"Please provide the config.json path.\"\n\n        # 1. 
Find the weight_path in config.json\n        with open(config_json_path, encoding=\"utf-8\") as i_f:\n            config = json.load(i_f)\n        if \"weight_path\" in config:\n            weight_path = Path(config[\"weight_path\"])\n            logger.info('Found \"weight_path\" in config.json: %s', weight_path)\n            if not weight_path.exists():\n                raise ValueError(f\"weight_path doesn't exist: {weight_path}\")\n        else:\n            # 2. Otherwise, look for the weight files in the same directory as config.json\n            weight_path = config_json_path.parent\n    else:\n        if not weight_path.exists():\n            raise ValueError(f\"weight_path doesn't exist: {weight_path}\")\n\n    logger.info(\"Finding weights in: %s\", weight_path)\n\n    # Check the weight format:\n    # if weight_format is \"auto\", guess the weight format;\n    # otherwise, check that the given weight format is valid.\n    if weight_format == \"auto\":\n        return _guess_weight_format(weight_path)\n\n    if weight_format not in AVAILABLE_WEIGHT_FORMAT:\n        raise ValueError(\n            f\"Available weight format list: {AVAILABLE_WEIGHT_FORMAT}, but got {weight_format}\"\n        )\n    if weight_format in CHECK_FORMAT_METHODS:\n        check_func = CHECK_FORMAT_METHODS[weight_format]\n        weight_config_path = check_func(weight_path)\n        if not weight_config_path:\n            raise ValueError(f\"The weights are not in {weight_format} format.\")\n    else:\n        weight_config_path = weight_path\n    return weight_config_path, weight_format\n\n\ndef _guess_weight_format(weight_path: Path) -> Tuple[Path, str]:\n    possible_formats: List[Tuple[Path, str]] = []\n    for weight_format, check_func in CHECK_FORMAT_METHODS.items():\n        weight_config_path = check_func(weight_path)\n        if weight_config_path:\n            possible_formats.append((weight_config_path, weight_format))\n\n    if len(possible_formats) == 0:\n        raise ValueError(\n            \"Failed to 
detect source weight format. \"\n            \"Use `--source-format` to explicitly specify the format.\"\n        )\n\n    weight_config_path, selected_format = possible_formats[0]\n    logger.info(\n        \"Using source weight configuration: %s. Use `--source` to override.\",\n        bold(str(weight_config_path)),\n    )\n    logger.info(\n        \"Using source weight format: %s. Use `--source-format` to override.\",\n        bold(selected_format),\n    )\n    return weight_config_path, selected_format\n\n\ndef _check_pytorch(weight_path: Path) -> Optional[Path]:\n    pytorch_json_path = weight_path / \"pytorch_model.bin.index.json\"\n    if pytorch_json_path.exists():\n        logger.info(\n            \"%s source weight format: huggingface-torch. Source configuration: %s\",\n            FOUND,\n            pytorch_json_path,\n        )\n        return pytorch_json_path\n\n    pytorch_file_path = weight_path / \"pytorch_model.bin\"\n    if pytorch_file_path.exists():\n        logger.info(\n            \"%s source weight format: huggingface-torch. Source configuration: %s\",\n            FOUND,\n            pytorch_file_path,\n        )\n        return pytorch_file_path\n\n    logger.info(\"%s Huggingface PyTorch\", NOT_FOUND)\n    return None\n\n\ndef _check_safetensor(weight_path: Path) -> Optional[Path]:\n    safetensor_json_path = weight_path / \"model.safetensors.index.json\"\n    if safetensor_json_path.exists():\n        logger.info(\n            \"%s source weight format: huggingface-safetensor. 
Source configuration: %s\",\n            FOUND,\n            safetensor_json_path,\n        )\n        return safetensor_json_path\n\n    safetensor_file_path = weight_path / \"model.safetensors\"\n    if safetensor_file_path.exists():\n        from safetensors.torch import (  # pylint: disable=import-outside-toplevel,import-error\n            load_file,\n        )\n\n        weights = load_file(safetensor_file_path, device=\"cpu\")\n        weight_map = {key: \"model.safetensors\" for key in weights}\n        with open(safetensor_json_path, \"w\", encoding=\"utf-8\") as file:\n            json.dump({\"weight_map\": weight_map}, file, indent=2)\n        logger.info(\n            \"%s source weight format: huggingface-safetensor. Source configuration: %s\",\n            FOUND,\n            safetensor_json_path,\n        )\n        return safetensor_json_path\n\n    logger.info(\"%s Huggingface Safetensor\", NOT_FOUND)\n    return None\n\n\nCHECK_FORMAT_METHODS = {\n    \"huggingface-torch\": _check_pytorch,\n    \"huggingface-safetensor\": _check_safetensor,\n}\n\n# \"ggml\", \"gguf\" are not supported yet.\nAVAILABLE_WEIGHT_FORMAT = [\"huggingface-torch\", \"huggingface-safetensor\", \"awq\"]\n"
  },
  {
    "path": "python/mlc_llm/support/config.py",
    "content": "\"\"\"\nA common base class for configuration. A configuration could be initialized from its constructor,\na JSON string or a JSON file, and irrelevant fields during initialization are automatically moved\nto the `kwargs` field.\n\nTake model configuration as an example: it is usually a JSON file in HuggingFace that contains\nthe model's hyperparameters. For instance, Vicuna-13b-v1.5-16k contains the following\n[JSON file](https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json).\nThe base class allows us to load the configuration from this JSON file, moving irrelevant fields\ninto `kwargs`, such as `transformers_version` and `use_cache`.\n\"\"\"\n\n# pylint: disable=too-few-public-methods\nimport dataclasses\nimport json\nfrom pathlib import Path\nfrom typing import Any, Dict, Type, TypeVar\n\nfrom . import logging\nfrom .style import bold, red\n\nlogger = logging.getLogger(__name__)\n\nConfigClass = TypeVar(\"ConfigClass\", bound=\"ConfigBase\")\n\n\n@dataclasses.dataclass\nclass ConfigBase:\n    \"\"\"Base class for configurations, providing a common interface for loading configs from a\n    JSON file or a dict. 
It requires the subclasses to be dataclasses with a `kwargs` field\n    that stores the extra fields that are not defined in the dataclass.\n    \"\"\"\n\n    @classmethod\n    def from_dict(cls: Type[ConfigClass], source: Dict[str, Any]) -> ConfigClass:\n        \"\"\"Create a config object from a dictionary.\n\n        Parameters\n        ----------\n        source : Dict[str, Any]\n            Source to create config from, usually loaded from `config.json` in HuggingFace style.\n\n        Returns\n        -------\n        cfg : ConfigClass\n            An instance of the config object.\n        \"\"\"\n        field_names = [field.name for field in dataclasses.fields(cls)]  # type: ignore[arg-type]\n        fields = {k: v for k, v in source.items() if k in field_names}\n        kwargs = {k: v for k, v in source.items() if k not in field_names}\n        return cls(**fields, kwargs=kwargs)  # type: ignore[call-arg]\n\n    @classmethod\n    def from_file(cls: Type[ConfigClass], source: Path) -> ConfigClass:\n        \"\"\"Create a config object from a file.\n\n        Parameters\n        ----------\n        source : pathlib.Path\n            Path to the source file, usually `config.json` in a HuggingFace repo.\n\n        Returns\n        -------\n        cfg : ConfigClass\n            An instance of the config object.\n        \"\"\"\n        with source.open(\"r\", encoding=\"utf-8\") as in_file:\n            return cls.from_dict(json.load(in_file))\n\n    def asdict(self):\n        \"\"\"Convert the config object to a dictionary.\n\n        Returns\n        -------\n        Dict[str, Any]\n            A dictionary representation of the config object, excluding the `kwargs` field.\n        \"\"\"\n        result = dataclasses.asdict(self)\n        result.pop(\"kwargs\")\n        return result\n\n\nclass ConfigOverrideBase:\n    \"\"\"Base class for ConfigOverride, providing a common 
interface for overriding configs.\n    It requires the subclasses to be dataclasses.\n    \"\"\"\n\n    def apply(self, config):\n        \"\"\"Apply the overrides to the given config.\"\"\"\n        updated = config.asdict()\n        for field in dataclasses.fields(self):\n            key = field.name\n            value = getattr(self, key)\n            if value is None:\n                continue\n            if key not in updated:\n                logger.warning(\n                    \"%s: Cannot override %s, because %s does not have this field\",\n                    red(\"Warning\"),\n                    bold(key),\n                    bold(type(config).__name__),\n                )\n            else:\n                logger.info(  # pylint: disable=logging-fstring-interpolation\n                    f\"Overriding {bold(key)} from {updated[key]} to {value}\"\n                )\n                updated[key] = value\n        return type(config).from_dict(updated)\n\n\n__all__ = [\"ConfigBase\", \"ConfigOverrideBase\"]\n"
  },
  {
    "path": "python/mlc_llm/support/constants.py",
    "content": "\"\"\"Environment variables used by MLC LLM.\"\"\"\n\nimport os\nimport sys\nfrom pathlib import Path\nfrom typing import List\n\nMLC_CHAT_CONFIG_VERSION = \"0.1.0\"\n\n\ndef _check():\n    if MLC_JIT_POLICY not in [\"ON\", \"OFF\", \"REDO\", \"READONLY\"]:\n        raise ValueError(\n            'Invalid MLC_JIT_POLICY. It has to be one of \"ON\", \"OFF\", \"REDO\", \"READONLY\", '\n            f\"but got {MLC_JIT_POLICY}.\"\n        )\n\n    if MLC_DOWNLOAD_CACHE_POLICY not in [\"ON\", \"OFF\", \"REDO\", \"READONLY\"]:\n        raise ValueError(\n            \"Invalid MLC_DOWNLOAD_CACHE_POLICY. \"\n            'It has to be one of \"ON\", \"OFF\", \"REDO\", \"READONLY\", '\n            f\"but got {MLC_DOWNLOAD_CACHE_POLICY}.\"\n        )\n\n\ndef _get_cache_dir() -> Path:\n    if \"MLC_LLM_HOME\" in os.environ:\n        result = Path(os.environ[\"MLC_LLM_HOME\"])\n    elif sys.platform == \"win32\":\n        result = Path(os.environ[\"LOCALAPPDATA\"])\n        result = result / \"mlc_llm\"\n    elif os.getenv(\"XDG_CACHE_HOME\", None) is not None:\n        result = Path(os.getenv(\"XDG_CACHE_HOME\"))\n        result = result / \"mlc_llm\"\n    else:\n        result = Path(os.path.expanduser(\"~/.cache\"))\n        result = result / \"mlc_llm\"\n    result.mkdir(parents=True, exist_ok=True)\n    if not result.is_dir():\n        raise ValueError(\n            f\"The default cache directory is not a directory: {result}. 
\"\n            \"Use environment variable MLC_LLM_HOME to specify a valid cache directory.\"\n        )\n    (result / \"model_weights\").mkdir(parents=True, exist_ok=True)\n    (result / \"model_lib\").mkdir(parents=True, exist_ok=True)\n    return result\n\n\ndef _get_dso_suffix() -> str:\n    if \"MLC_DSO_SUFFIX\" in os.environ:\n        return os.environ[\"MLC_DSO_SUFFIX\"]\n    if sys.platform == \"win32\":\n        return \"dll\"\n    if sys.platform == \"darwin\":\n        return \"dylib\"\n    return \"so\"\n\n\ndef _get_test_model_path() -> List[Path]:\n    paths = []\n    if \"MLC_LLM_TEST_MODEL_PATH\" in os.environ:\n        paths += [Path(p) for p in os.environ[\"MLC_LLM_TEST_MODEL_PATH\"].split(os.pathsep)]\n    # by default, we reuse the cache dir via mlc_llm chat\n    # note that we do not auto download for testcase\n    # to avoid networking dependencies\n    base_list = [\"hf\"]\n    paths += [_get_cache_dir() / \"model_weights\" / base / \"mlc-ai\" for base in base_list] + [\n        Path(os.path.abspath(os.path.curdir)),\n        Path(os.path.abspath(os.path.curdir)) / \"dist\",\n    ]\n    return paths\n\n\ndef _get_read_only_weight_caches() -> List[Path]:\n    if \"MLC_LLM_READONLY_WEIGHT_CACHE\" in os.environ:\n        return [Path(p) for p in os.environ[\"MLC_LLM_READONLY_WEIGHT_CACHE\"].split(os.pathsep)]\n    return []\n\n\nMLC_TEMP_DIR = os.getenv(\"MLC_TEMP_DIR\", None)\nMLC_MULTI_ARCH = os.environ.get(\"MLC_MULTI_ARCH\", None)\nMLC_JIT_POLICY = os.environ.get(\"MLC_JIT_POLICY\", \"ON\")\nMLC_DSO_SUFFIX = _get_dso_suffix()\nMLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path()\n\nMLC_DOWNLOAD_CACHE_POLICY = os.environ.get(\"MLC_DOWNLOAD_CACHE_POLICY\", \"ON\")\nMLC_LLM_HOME: Path = _get_cache_dir()\nMLC_LLM_READONLY_WEIGHT_CACHE = _get_read_only_weight_caches()\n\n_check()\n"
  },
  {
    "path": "python/mlc_llm/support/convert_tiktoken.py",
    "content": "\"\"\"\nAdapted from https://gist.github.com/xenova/a452a6474428de0182b17605a98631ee\nGenerator of mlc-chat-config.json and tokenizer configuration.\n\"\"\"\n\n# pylint: disable=import-error\n# isort: off\nimport json\nimport os\nfrom typing import Dict, List, Optional\n\n\ndef bpe(\n    mergeable_ranks: Dict[bytes, int], token: bytes, max_rank: Optional[int] = None\n) -> List[bytes]:\n    \"\"\"Adapted from https://github.com/openai/tiktoken/issues/60#issuecomment-1499977960\"\"\"\n    parts = [bytes([b]) for b in token]\n    while True:\n        min_idx = None\n        min_rank = None\n        for i, pair in enumerate(zip(parts[:-1], parts[1:])):\n            rank = mergeable_ranks.get(pair[0] + pair[1])\n            if rank is not None and (min_rank is None or rank < min_rank):\n                min_idx = i\n                min_rank = rank\n        if min_rank is None or (max_rank is not None and min_rank >= max_rank):\n            break\n        assert min_idx is not None\n        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2 :]\n    return parts\n\n\ndef generate_vocab_and_merges(encoder, mergeable_ranks):\n    \"\"\"Generate vocab and merges in huggingface tokenizers format\"\"\"\n\n    from transformers.models.gpt2.tokenization_gpt2 import (  # pylint: disable=import-outside-toplevel\n        bytes_to_unicode,\n    )\n\n    byte_encoder = bytes_to_unicode()\n\n    def token_bytes_to_string(b):\n        \"\"\"Convert a token from bytes to a string\"\"\"\n        return \"\".join([byte_encoder[ord(char)] for char in b.decode(\"latin-1\")])\n\n    merges = []\n    vocab = {}\n    for token, rank in mergeable_ranks.items():\n        vocab[token_bytes_to_string(token)] = rank\n\n        if len(token) == 1:\n            continue\n        merged = tuple(bpe(mergeable_ranks, token, max_rank=rank))\n        assert len(merged) == 2\n\n        merges.append(\" \".join(map(token_bytes_to_string, merged)))\n\n    # 
Also add special tokens\n    vocab.update(encoder._special_tokens)  # pylint: disable=protected-access\n\n    return vocab, merges\n\n\ndef convert_tiktoken(model_path, output_dir, context_window_size=None):\n    \"\"\"Convert tiktoken tokenizers to huggingface tokenizers style\"\"\"\n    try:\n        from transformers import AutoTokenizer  # pylint: disable=import-outside-toplevel\n    except ImportError:\n        raise ImportError(  # pylint: disable=raise-missing-from\n            'Converting tiktoken tokenizer requires the \"transformers\" package.'\n            'Please install the \"transformers\" package to convert toktoken tokenizer'\n        )\n\n    tiktoken_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n    encoder = tiktoken_tokenizer.tokenizer\n\n    vocab, merges = generate_vocab_and_merges(encoder, tiktoken_tokenizer.get_vocab())\n\n    added_tokens = [\n        {\n            \"id\": id,\n            \"content\": content,\n            \"single_word\": False,\n            \"lstrip\": False,\n            \"rstrip\": False,\n            \"normalized\": False,\n            \"special\": True,\n        }\n        for content, id in encoder._special_tokens.items()  # pylint: disable=protected-access\n    ]\n\n    tokenizer_template = {\n        \"version\": \"1.0\",\n        \"truncation\": None,\n        \"padding\": None,\n        \"added_tokens\": added_tokens,\n        \"normalizer\": None,\n        \"pre_tokenizer\": {\n            \"type\": \"ByteLevel\",\n            \"add_prefix_space\": False,\n            \"trim_offsets\": True,\n            \"use_regex\": True,\n        },\n        \"post_processor\": {\n            \"type\": \"ByteLevel\",\n            \"add_prefix_space\": True,\n            \"trim_offsets\": False,\n            \"use_regex\": True,\n        },\n        \"decoder\": {\n            \"type\": \"ByteLevel\",\n            \"add_prefix_space\": True,\n            \"trim_offsets\": True,\n            
\"use_regex\": True,\n        },\n        \"model\": {\n            \"type\": \"BPE\",\n            \"dropout\": None,\n            \"unk_token\": None,\n            \"continuing_subword_prefix\": \"\",\n            \"end_of_word_suffix\": \"\",\n            \"fuse_unk\": False,\n            \"byte_fallback\": False,\n            \"vocab\": vocab,\n            \"merges\": merges,\n        },\n    }\n\n    tokenizer_config_template = {\n        \"add_prefix_space\": False,\n        \"bos_token\": \"<|endoftext|>\",\n        \"clean_up_tokenization_spaces\": True,\n        \"eos_token\": \"<|endoftext|>\",\n        \"unk_token\": \"<|endoftext|>\",\n    }\n\n    tokenizer_name = type(tiktoken_tokenizer).__name__\n\n    tokenizer_config_template[\"tokenizer_class\"] = tokenizer_name\n    if context_window_size:\n        tokenizer_config_template[\"model_max_length\"] = context_window_size\n    tokenizer_config_template = dict(sorted(tokenizer_config_template.items(), key=lambda x: x[0]))\n\n    os.makedirs(output_dir, exist_ok=True)\n\n    # Save to files\n    with open(os.path.join(output_dir, \"vocab.json\"), \"w\", encoding=\"utf-8\") as fp:\n        json.dump(vocab, fp, indent=2, ensure_ascii=False)\n\n    with open(os.path.join(output_dir, \"tokenizer.json\"), \"w\", encoding=\"utf-8\") as fp:\n        json.dump(tokenizer_template, fp, indent=2, ensure_ascii=False)\n\n    with open(os.path.join(output_dir, \"tokenizer_config.json\"), \"w\", encoding=\"utf-8\") as fp:\n        json.dump(tokenizer_config_template, fp, indent=2, ensure_ascii=False)\n\n    with open(os.path.join(output_dir, \"special_tokens_map.json\"), \"w\", encoding=\"utf-8\") as fp:\n        json.dump(\n            {\n                \"bos_token\": \"<|endoftext|>\",\n                \"eos_token\": \"<|endoftext|>\",\n                \"unk_token\": \"<|endoftext|>\",\n            },\n            fp,\n            indent=2,\n            ensure_ascii=False,\n        )\n\n    with 
open(os.path.join(output_dir, \"merges.txt\"), \"w\", encoding=\"utf-8\") as fp:\n        fp.write(\"#version: 0.2\\n\")\n        fp.write(\"\\n\".join(merges))\n"
  },
  {
    "path": "python/mlc_llm/support/download_cache.py",
    "content": "\"\"\"Common utilities for downloading files from HuggingFace or other URLs online.\"\"\"\n\nimport concurrent.futures as cf\nimport hashlib\nimport json\nimport os\nimport shutil\nimport subprocess\nimport tempfile\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple\n\nimport requests  # pylint: disable=import-error\n\nfrom . import logging, tqdm\nfrom .constants import (\n    MLC_DOWNLOAD_CACHE_POLICY,\n    MLC_LLM_HOME,\n    MLC_LLM_READONLY_WEIGHT_CACHE,\n    MLC_TEMP_DIR,\n)\nfrom .style import bold\n\nlogger = logging.getLogger(__name__)\n\n\ndef log_download_cache_policy():\n    \"\"\"log current download policy\"\"\"\n    logger.info(\n        \"%s = %s. Can be one of: ON, OFF, REDO, READONLY\",\n        bold(\"MLC_DOWNLOAD_CACHE_POLICY\"),\n        MLC_DOWNLOAD_CACHE_POLICY,\n    )\n\n\ndef _ensure_directory_not_exist(path: Path, force_redo: bool) -> None:\n    if path.exists():\n        if force_redo:\n            logger.info(\"Deleting existing directory: %s\", path)\n            shutil.rmtree(path)\n        else:\n            raise ValueError(f\"Directory already exists: {path}\")\n    else:\n        path.parent.mkdir(parents=True, exist_ok=True)\n\n\ndef git_clone(url: str, destination: Path, ignore_lfs: bool) -> None:\n    \"\"\"Clone a git repository into a directory.\"\"\"\n    repo_name = \".tmp\"\n    command = [\"git\", \"clone\", url, repo_name]\n    _ensure_directory_not_exist(destination, force_redo=False)\n    try:\n        env = os.environ.copy()\n        env[\"GIT_LFS_SKIP_SMUDGE\"] = \"1\"\n        with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir:\n            logger.info(\"[Git] Cloning %s to %s\", bold(url), destination)\n            subprocess.run(\n                command,\n                env=env,\n                cwd=tmp_dir,\n                check=True,\n                stdout=subprocess.DEVNULL,\n                stderr=subprocess.DEVNULL,\n            )\n            git_dir = 
os.path.join(tmp_dir, repo_name)\n            if not ignore_lfs:\n                git_lfs_pull(Path(git_dir))\n            shutil.move(git_dir, str(destination))\n    except subprocess.CalledProcessError as error:\n        raise ValueError(\n            f\"Git clone failed with return code {error.returncode}: {error.stderr}. \"\n            f\"The command was: {command}\"\n        ) from error\n\n\ndef git_lfs_pull(repo_dir: Path, ignore_extensions: Optional[List[str]] = None) -> None:\n    \"\"\"Pull files with Git LFS.\"\"\"\n    filenames = (\n        subprocess.check_output(\n            [\"git\", \"-C\", str(repo_dir), \"lfs\", \"ls-files\", \"-n\"],\n            stderr=subprocess.STDOUT,\n        )\n        .decode(\"utf-8\")\n        .splitlines()\n    )\n    if ignore_extensions is not None:\n        filenames = [\n            filename\n            for filename in filenames\n            if not any(filename.endswith(extension) for extension in ignore_extensions)\n        ]\n    logger.info(\"[Git LFS] Downloading %d files with Git LFS: %s\", len(filenames), filenames)\n    with tqdm.redirect():\n        for file in tqdm.tqdm(filenames):\n            logger.info(\"[Git LFS] Downloading %s\", file)\n            subprocess.check_output(\n                [\"git\", \"-C\", str(repo_dir), \"lfs\", \"pull\", \"--include\", file],\n                stderr=subprocess.STDOUT,\n            )\n\n\ndef download_file(\n    url: str,\n    destination: Path,\n    md5sum: Optional[str],\n) -> Tuple[str, Path]:\n    \"\"\"Download a file from a URL to a destination file.\"\"\"\n    with requests.get(url, stream=True, timeout=30) as response:\n        response.raise_for_status()  # type: ignore\n        with destination.open(\"wb\") as file:\n            for chunk in response.iter_content(chunk_size=8192):  # type: ignore\n                file.write(chunk)\n    if md5sum is not None:\n        hash_md5 = hashlib.md5()\n        with destination.open(\"rb\") as file:\n            
for chunk in iter(lambda: file.read(8192), b\"\"):\n                hash_md5.update(chunk)\n        file_md5 = hash_md5.hexdigest()\n        if file_md5 != md5sum:\n            raise ValueError(\n                f\"MD5 checksum mismatch for downloaded file: {destination}. \"\n                f\"Expected {md5sum}, got {file_md5}\"\n            )\n    return url, destination\n\n\ndef download_and_cache_mlc_weights(  # pylint: disable=too-many-locals\n    model_url: str,\n    num_processes: int = 4,\n    force_redo: Optional[bool] = None,\n) -> Path:\n    \"\"\"Download weights for a model from the HuggingFace Git LFS repo.\"\"\"\n    log_download_cache_policy()\n    if MLC_DOWNLOAD_CACHE_POLICY == \"OFF\":\n        raise RuntimeError(f\"Cannot download {model_url} as MLC_DOWNLOAD_CACHE_POLICY=OFF\")\n\n    prefixes, mlc_prefix = [\"HF://\", \"https://huggingface.co/\"], \"\"\n    mlc_prefix = next(p for p in prefixes if model_url.startswith(p))\n    assert mlc_prefix\n\n    git_url_template = \"https://huggingface.co/{user}/{repo}\"\n    bin_url_template = \"https://huggingface.co/{user}/{repo}/resolve/main/{record_name}\"\n\n    if model_url.count(\"/\") != 1 + mlc_prefix.count(\"/\") or not model_url.startswith(mlc_prefix):\n        raise ValueError(f\"Invalid model URL: {model_url}\")\n    user, repo = model_url[len(mlc_prefix) :].split(\"/\")\n    domain = \"hf\"\n\n    readonly_cache_dirs = []\n    for base in MLC_LLM_READONLY_WEIGHT_CACHE:\n        cache_dir = base / domain / user / repo\n        readonly_cache_dirs.append(str(cache_dir))\n        if (cache_dir / \"mlc-chat-config.json\").is_file():\n            logger.info(\"Use cached weight: %s\", bold(str(cache_dir)))\n            return cache_dir\n\n    if force_redo is None:\n        force_redo = MLC_DOWNLOAD_CACHE_POLICY == \"REDO\"\n\n    git_dir = MLC_LLM_HOME / \"model_weights\" / domain / user / repo\n    readonly_cache_dirs.append(str(git_dir))\n\n    try:\n        
_ensure_directory_not_exist(git_dir, force_redo=force_redo)\n    except ValueError:\n        logger.info(\"Weights already downloaded: %s\", bold(str(git_dir)))\n        return git_dir\n\n    if MLC_DOWNLOAD_CACHE_POLICY == \"READONLY\":\n        raise RuntimeError(\n            f\"Cannot find cache for {model_url}, \"\n            \"cannot proceed to download as MLC_DOWNLOAD_CACHE_POLICY=READONLY, \"\n            \"please check settings MLC_LLM_READONLY_WEIGHT_CACHE, \"\n            f\"local path candidates: {readonly_cache_dirs}\"\n        )\n\n    with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir_prefix:\n        tmp_dir = Path(tmp_dir_prefix) / \"tmp\"\n        git_url = git_url_template.format(user=user, repo=repo)\n        git_clone(git_url, tmp_dir, ignore_lfs=True)\n        git_lfs_pull(tmp_dir, ignore_extensions=[\".bin\"])\n        shutil.rmtree(tmp_dir / \".git\", ignore_errors=True)\n        with (tmp_dir / \"tensor-cache.json\").open(encoding=\"utf-8\") as in_file:\n            param_metadata = json.load(in_file)[\"records\"]\n        with cf.ProcessPoolExecutor(max_workers=num_processes) as executor:\n            futures = []\n            for record in param_metadata:\n                record_name = record[\"dataPath\"]\n                file_url = bin_url_template.format(user=user, repo=repo, record_name=record_name)\n                file_dest = tmp_dir / record_name\n                file_md5 = record.get(\"md5sum\", None)\n                futures.append(executor.submit(download_file, file_url, file_dest, file_md5))\n            with tqdm.redirect():\n                for future in tqdm.tqdm(cf.as_completed(futures), total=len(futures)):\n                    file_url, file_dest = future.result()\n                    logger.info(\"Downloaded %s to %s\", file_url, file_dest)\n        logger.info(\"Moving %s to %s\", tmp_dir, bold(str(git_dir)))\n        shutil.move(str(tmp_dir), str(git_dir))\n    return git_dir\n\n\ndef 
get_or_download_model(model: str) -> Path:\n    \"\"\"Use the user-provided argument ``model`` to get ``model_path``.\n\n    We define \"valid\" as having an ``mlc-chat-config.json`` right under the folder.\n\n    Parameters\n    ----------\n    model : str\n        User's input; may be a local path or a URL\n\n    Returns\n    -------\n    model_path : Path\n        A \"valid\" path to the model folder, with\n        ``(model_path / \"mlc-chat-config.json\").is_file()`` being True\n\n    Note\n    ----\n    This function may download and cache the model weights\n\n    Raises\n    ------\n    FileNotFoundError: if we cannot find a valid `model_path`.\n    \"\"\"\n    if model.startswith(\"HF://\"):\n        logger.info(\"Downloading model from HuggingFace: %s\", model)\n        model_path = download_and_cache_mlc_weights(model)\n    else:\n        model_path = Path(model)\n\n    if not model_path.is_dir():\n        raise FileNotFoundError(f\"Cannot find model {model}, directory does not exist\")\n    mlc_config_path = model_path / \"mlc-chat-config.json\"\n    if mlc_config_path.is_file():\n        return model_path\n    raise FileNotFoundError(f\"Cannot find {str(mlc_config_path)} in the model directory provided\")\n"
  },
  {
    "path": "python/mlc_llm/support/logging.py",
    "content": "\"\"\"\nLogging support for MLC. It derives from Python's logging module, and in the future,\nit can be easily replaced by other logging modules such as structlog.\n\"\"\"\n\nimport logging\nimport os\n\n\ndef enable_logging():\n    \"\"\"Enable MLC's default logging format\"\"\"\n    if os.getenv(\"MLC_UNSET_LOGGING\"):\n        return\n    logging.basicConfig(\n        level=logging.INFO,\n        style=\"{\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n        format=\"[{asctime}] {levelname} {filename}:{lineno}: {message}\",\n    )\n\n\ndef getLogger(name: str):  # pylint: disable=invalid-name\n    \"\"\"Get a logger according to the given name\"\"\"\n    return logging.getLogger(name)\n"
  },
  {
    "path": "python/mlc_llm/support/max_thread_check.py",
    "content": "\"\"\"Helper functions for checking the maximum number of threads.\"\"\"\n\nfrom tvm.target import Target\n\n\ndef get_max_num_threads_per_block(target: Target) -> int:\n    \"\"\"\n    max(max_num_threads, max_threads_per_block); if the latter does not exist,\n    return max_num_threads.\n    We add this method since some targets have both fields and `max_threads_per_block` is larger.\n    \"\"\"\n    max_num_threads = target.attrs.get(\"max_num_threads\")\n    max_threads_per_block = target.attrs.get(\"max_threads_per_block\", None)\n    if max_threads_per_block is None:\n        return max_num_threads\n    return max(max_num_threads, max_threads_per_block)\n\n\ndef check_thread_limits(target: Target, bdx: int, bdy: int, bdz: int, gdz: int):\n    \"\"\"\n    Check whether the maximum number of threads is exceeded for the given target.\n\n    Parameters\n    ----------\n    bdx: threadIdx.x\n    bdy: threadIdx.y\n    bdz: threadIdx.z\n    gdz: blockIdx.z\n    \"\"\"\n    max_num_threads_per_block = get_max_num_threads_per_block(target)\n\n    assert (\n        bdx * bdy * bdz <= max_num_threads_per_block\n    ), f\"{target.kind} max num threads exceeded: {bdx}*{bdy}*{bdz}>{max_num_threads_per_block}\"\n\n    if str(target.kind) == \"webgpu\":\n        # https://gpuweb.github.io/gpuweb/#dom-supported-limits-maxcomputeworkgroupsizez\n        assert bdz <= 64, f\"webgpu's threadIdx.z cannot exceed 64, but got bdz={bdz}\"\n        assert gdz == 1, f\"webgpu's blockIdx.z should be 1, but got gdz={gdz}\"\n"
  },
  {
    "path": "python/mlc_llm/support/preshard.py",
    "content": "\"\"\"Functions for pre-sharding weights\"\"\"\n\nimport logging\nfrom typing import Any, Callable, Dict, Sequence, Tuple\n\nfrom tvm import IRModule, relax\nfrom tvm.relax.frontend import nn\nfrom tvm.runtime import Device, Tensor\nfrom tvm.s_tir import dlight as dl\nfrom tvm.target import Target\n\nlogger = logging.getLogger(\"preshard\")\n\n\ndef _sharded_param_name(param_name, worker_id):\n    return f\"{param_name}_shard-{worker_id}\"\n\n\ndef _create_shard_func(\n    bb: relax.BlockBuilder, param: nn.Parameter, tensor_parallel_shards: int\n):  # pylint: disable=too-many-locals\n    shard_strategy = param.attrs.get(\"shard_strategy\", None)\n    # generate tir shard function\n    tir_func = shard_strategy.gen_tir(shards=tensor_parallel_shards, weight=param)\n    tir_func = tir_func.with_attr(\"global_symbol\", f\"{shard_strategy.name}_tir\")\n    # add tir shard function to the IRModule\n    tir_gvar = bb.add_func(tir_func, func_name=f\"{shard_strategy.name}_tir\")\n    # create relax function that\n    #     1. shard weight with tir shard function, result: [num_shards, *sharded_weight_shape]\n    #     2. split the sharded weight along dim 0, result: num_shards * [1, *sharded_weight_shape]\n    #     3. 
squeeze the 0th-dim of all shards, result: num_shards * [*sharded_weight_shape]\n    weight_shape = param.shape\n    weight_shape[shard_strategy.dim] = weight_shape[shard_strategy.dim] * tensor_parallel_shards\n    sharded_weight_shape = [tensor_parallel_shards, *param.shape]\n    weight_var = relax.Var(\"weight\", relax.TensorStructInfo(weight_shape, param.dtype))\n    with bb.function(name=shard_strategy.name, params=[weight_var]):\n        with bb.dataflow():\n            lv0 = bb.emit(\n                relax.call_tir(\n                    tir_gvar,\n                    weight_var,\n                    out_sinfo=relax.TensorStructInfo(sharded_weight_shape, param.dtype),\n                )\n            )\n            lv1 = bb.emit(relax.op.split(lv0, indices_or_sections=tensor_parallel_shards, axis=0))\n            output_vars = []\n            for i in range(tensor_parallel_shards):\n                lvi = bb.emit(relax.TupleGetItem(lv1, i))\n                squeezed_lvi = bb.emit(relax.op.squeeze(lvi, 0))\n                output_vars.append(squeezed_lvi)\n            gv = bb.emit_output(output_vars)\n        bb.emit_func_output(gv)\n\n\ndef _compile_shard_funcs(mod: IRModule, device: Device):\n    target = Target.from_device(device)\n    with target:\n        mod = relax.transform.LegalizeOps()(mod)\n        mod = dl.ApplyDefaultSchedule(  # type: ignore   # pylint: disable=not-callable\n            dl.gpu.Matmul(),\n            dl.gpu.GEMV(),\n            dl.gpu.Reduction(),\n            dl.gpu.GeneralReduction(),\n            dl.gpu.Fallback(),\n        )(mod)\n    ex = relax.build(mod, target=target)\n    vm = relax.VirtualMachine(ex, device)\n    return vm\n\n\ndef apply_preshard(\n    named_params: Dict[str, nn.Parameter],\n    tensor_parallel_shards: int,\n    args: Any,\n) -> Tuple[Dict[str, nn.Parameter], Dict[str, Callable[[Tensor], Sequence[Tensor]]]]:\n    \"\"\"Apply pre-sharding to the named parameters.\n\n    Parameters\n    ----------\n    
named_params : Dict[str, nn.Parameter]\n        The named parameters of the model. If the model is quantized, the named parameters should\n        be the state dictionary of the quantized model.\n    tensor_parallel_shards : int\n        The number of tensor parallel shards.\n    args : Any\n        The parsed arguments of weight conversion.\n\n    Returns\n    -------\n    Tuple[Dict[str, nn.Parameter], Dict[str, Callable[[Tensor], Sequence[Tensor]]]]\n        The updated named parameters and the mapping from parameter name to the shard function.\n    \"\"\"\n    bb = relax.BlockBuilder()\n    param_to_shard_func = {}\n    shard_func_names = set()\n    new_named_params: Dict[str, nn.Parameter] = {}\n    has_shard_strategy = False\n    for name, param in named_params.items():\n        shard_strategy = param.attrs.get(\"shard_strategy\", None)\n        if shard_strategy is not None:\n            has_shard_strategy = True\n            for i in range(tensor_parallel_shards):\n                new_named_params[_sharded_param_name(name, i)] = param\n            # create shard functions\n            param_to_shard_func[name] = shard_strategy.name\n            if shard_strategy.name not in shard_func_names:\n                _create_shard_func(bb, param, tensor_parallel_shards)\n                shard_func_names.add(shard_strategy.name)\n        else:\n            new_named_params[name] = param\n\n    if not has_shard_strategy:\n        logger.warning(\n            \"No parameters with 'shard_strategy' found. \"\n            \"At least one parameter must have a 'shard_strategy' for presharding. \"\n            \"The model will continue to convert weights in a non-presharded manner.\"\n        )\n\n    mod = bb.finalize()\n    vm = _compile_shard_funcs(mod, args.device)\n\n    for name in param_to_shard_func:\n        param_to_shard_func[name] = vm[param_to_shard_func[name]]\n    return new_named_params, param_to_shard_func\n"
  },
  {
    "path": "python/mlc_llm/support/random.py",
    "content": "\"\"\"Utility functions for random number generation.\"\"\"\n\nimport sys\n\n\ndef set_global_random_seed(seed):\n    \"\"\"Set global random seed for python, numpy, torch and tvm.\"\"\"\n    if \"numpy\" in sys.modules:\n        sys.modules[\"numpy\"].random.seed(seed)\n    if \"torch\" in sys.modules:\n        sys.modules[\"torch\"].manual_seed(seed)\n    if \"random\" in sys.modules:\n        sys.modules[\"random\"].seed(seed)  # pylint: disable=no-member\n    if \"tvm\" in sys.modules:\n        set_seed = sys.modules[\"tvm\"].get_global_func(\"mlc.random.set_seed\")\n        if set_seed:\n            set_seed(seed)\n"
  },
  {
    "path": "python/mlc_llm/support/style.py",
    "content": "\"\"\"Printing styles.\"\"\"\n\nfrom enum import Enum\n\n\nclass Styles(Enum):\n    \"\"\"Predefined set of styles to be used.\n\n    Reference:\n    - https://en.wikipedia.org/wiki/ANSI_escape_code#3-bit_and_4-bit\n    - https://stackoverflow.com/a/17303428\n    \"\"\"\n\n    RED = \"\\033[91m\"\n    GREEN = \"\\033[92m\"\n    YELLOW = \"\\033[93m\"\n    BLUE = \"\\033[94m\"\n    PURPLE = \"\\033[95m\"\n    CYAN = \"\\033[96m\"\n    BOLD = \"\\033[1m\"\n    UNDERLINE = \"\\033[4m\"\n    END = \"\\033[0m\"\n\n\ndef red(text: str) -> str:\n    \"\"\"Return red text.\"\"\"\n    return f\"{Styles.RED.value}{text}{Styles.END.value}\"\n\n\ndef green(text: str) -> str:\n    \"\"\"Return green text.\"\"\"\n    return f\"{Styles.GREEN.value}{text}{Styles.END.value}\"\n\n\ndef yellow(text: str) -> str:\n    \"\"\"Return yellow text.\"\"\"\n    return f\"{Styles.YELLOW.value}{text}{Styles.END.value}\"\n\n\ndef blue(text: str) -> str:\n    \"\"\"Return blue text.\"\"\"\n    return f\"{Styles.BLUE.value}{text}{Styles.END.value}\"\n\n\ndef purple(text: str) -> str:\n    \"\"\"Return purple text.\"\"\"\n    return f\"{Styles.PURPLE.value}{text}{Styles.END.value}\"\n\n\ndef cyan(text: str) -> str:\n    \"\"\"Return cyan text.\"\"\"\n    return f\"{Styles.CYAN.value}{text}{Styles.END.value}\"\n\n\ndef bold(text: str) -> str:\n    \"\"\"Return bold text.\"\"\"\n    return f\"{Styles.BOLD.value}{text}{Styles.END.value}\"\n\n\ndef underline(text: str) -> str:\n    \"\"\"Return underlined text.\"\"\"\n    return f\"{Styles.UNDERLINE.value}{text}{Styles.END.value}\"\n"
  },
  {
    "path": "python/mlc_llm/support/tensor_parallel.py",
    "content": "\"\"\"Sharding operators for tensor parallelism.\"\"\"\n\nimport dataclasses\nfrom contextlib import contextmanager\nfrom typing import Any, Dict, List, Optional\n\nfrom tvm import te, tir, topi\nfrom tvm.relax.frontend import nn\n\n\n@dataclasses.dataclass\nclass ShardSingleDim:\n    \"\"\"\n    Shard a tensor by a single dimension.\n\n\n    Parameters\n    ----------\n    name : str\n        The name of the shard func\n\n    dim : int\n        The dimension to shard\n\n    segs : Optional[List[int]]\n        The length of segments along `dim`. Default to None. If specified,\n        shard a tensor by its \"segmented\" dimension, where each segment has a different length\n        and sharded evenly on each worker.\n\n    \"\"\"\n\n    name: str\n    dim: int\n    segs: Optional[List[int]] = None\n\n    def gen_tir(self, shards: int, weight: nn.Tensor) -> tir.PrimFunc:\n        \"\"\"Generate a TIR function that shards the weight tensor by its rows.\"\"\"\n        shape = weight.shape\n        segs = self.segs or [shape[self.dim]]\n        assert sum(segs) == shape[self.dim]\n        # NOTE: we use int64 to prevent int32 overflow\n        shape = [tir.IntImm(\"int64\", v) for v in shape]\n        segs = [tir.IntImm(\"int64\", v) for v in segs]\n        w = te.placeholder(\n            [tir.IntImm(\"int64\", v) for v in self._compute_in_shape(shards, weight)],\n            weight.dtype,\n            name=\"w\",\n        )\n        ws: List[te.Tensor] = []\n        offset = 0\n        for idx, sub_seg in enumerate(segs):\n            ws.append(\n                topi.transpose(\n                    topi.reshape(\n                        te.compute(\n                            (\n                                *shape[: self.dim],\n                                sub_seg * shards,\n                                *shape[self.dim + 1 :],\n                            ),\n                            lambda *idx: w[\n                                idx[: 
self.dim]\n                                + (idx[self.dim] + offset,)  # pylint: disable=cell-var-from-loop\n                                + idx[self.dim + 1 :]\n                            ],\n                            name=f\"w_{idx}\",\n                        ),\n                        (\n                            *shape[: self.dim],\n                            tir.IntImm(\"int64\", shards),\n                            sub_seg,\n                            *shape[self.dim + 1 :],\n                        ),\n                    ),\n                    [self.dim, *range(self.dim), *range(self.dim + 1, len(shape) + 1)],\n                )\n            )\n            offset += sub_seg * shards\n        o = topi.concatenate(ws, axis=1 + self.dim)\n        func = te.create_prim_func([w, o])\n        return func\n\n    def gen_shard_info(self, shards: int, weight: nn.Tensor) -> Dict[str, Any]:\n        \"\"\"Generate shard info for this sharding strategy.\"\"\"\n        return {\n            \"func_name\": self.name,\n            \"in_shape\": self._compute_in_shape(shards, weight),\n            \"out_shape\": (shards, *weight.shape),\n            \"out_dtype\": weight.dtype,\n        }\n\n    def _compute_in_shape(self, shards: int, weight: nn.Tensor) -> List[int]:\n        \"\"\"Compute the weight shape before sharding.\"\"\"\n        shape = weight.shape\n        return [*shape[: self.dim], shape[self.dim] * shards, *shape[self.dim + 1 :]]\n\n\n@contextmanager\ndef shard_bias(linear: nn.Linear, tensor_parallel_shards: int):\n    \"\"\"\n    A context manager to shard the bias of a linear into `tensor_parallel_shards` shards.\n\n\n    Parameters\n    ----------\n    linear : nn.Linear\n        The linear layer whose bias would be sharded.\n\n    tensor_parallel_shards : int\n        The number of shards.\n    \"\"\"\n    original_bias = linear.bias\n    if tensor_parallel_shards > 1:\n        linear.bias = linear.bias / tensor_parallel_shards\n    yield\n 
   linear.bias = original_bias\n"
  },
  {
    "path": "python/mlc_llm/support/tqdm.py",
    "content": "\"\"\"Utils to better use tqdm\"\"\"\n\nimport contextlib\nimport inspect\nimport io\n\nfrom tqdm import tqdm\nfrom tqdm.contrib.logging import logging_redirect_tqdm as _redirect_logging\n\n\n@contextlib.contextmanager\ndef _redirect_print():\n    old_print = print\n\n    def new_print(*args, **kwargs):\n        with io.StringIO() as output:\n            kwargs[\"file\"] = output\n            kwargs[\"end\"] = \"\"\n            old_print(*args, **kwargs)\n            content = output.getvalue()\n        tqdm.write(content)\n\n    try:\n        inspect.builtins.print = new_print\n        yield\n    finally:\n        inspect.builtins.print = old_print\n\n\n@contextlib.contextmanager\ndef redirect():\n    \"\"\"Redirect tqdm output to logging and print.\"\"\"\n\n    with _redirect_logging():\n        with _redirect_print():\n            yield\n\n\n__all__ = [\"tqdm\", \"redirect\"]\n"
  },
  {
    "path": "python/mlc_llm/testing/__init__.py",
    "content": "\"\"\"\nTest and debug tools for MLC LLM\n\"\"\"\n\nfrom .pytest_utils import require_test_model, require_test_tokenizers\n"
  },
  {
    "path": "python/mlc_llm/testing/debug_chat.py",
    "content": "\"\"\"Debug compiled models with TVM instrument\"\"\"\n\n# pylint: disable=too-many-arguments\nimport json\nimport random\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport tvm\nimport tvm_ffi\nfrom tvm import DataType, relax\nfrom tvm.contrib import tvmjs\nfrom tvm.runtime import Device, Module, Object, ShapeTuple\nfrom tvm.runtime.vm import VirtualMachine\n\nfrom mlc_llm.conversation_template import ConvTemplateRegistry\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.protocol.mlc_chat_config import MLCChatConfig\nfrom mlc_llm.serve import data, engine_utils\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.support.auto_device import detect_device\nfrom mlc_llm.support.style import green, red\nfrom mlc_llm.tokenizers import Tokenizer\n\n\ndef _extract_metadata(mod: Module):\n    return json.loads(VirtualMachine(mod, tvm.runtime.device(\"cpu\"))[\"_metadata\"]())\n\n\ndef _load_params(\n    model_weight_path: str, device: Device, model_metadata: Dict[str, Any]\n) -> List[tvm.runtime.Tensor]:\n    params, meta = tvmjs.load_tensor_cache(model_weight_path, device)\n    param_names = [param[\"name\"] for param in model_metadata[\"params\"]]\n    assert len(param_names) == meta[\"ParamSize\"]\n\n    plist = []\n    for param_name in param_names:\n        plist.append(params[param_name])\n    return plist\n\n\ndef _get_tvm_module(\n    model_weight_path: str,\n    lib_path: str,\n    device: Device,\n    instrument: Union[tvm_ffi.Function, None],\n):\n    ex = tvm.runtime.load_module(lib_path)\n    vm = relax.VirtualMachine(ex, device)\n    if instrument is not None:\n        vm.set_instrument(instrument)\n    metadata = _extract_metadata(ex)\n    params = _load_params(model_weight_path, device, metadata)\n    return vm.module, params, metadata\n\n\nclass DefaultDebugInstrument:\n    \"\"\"The default debug instrument to use if users don't specify\n    a customized 
one.\n\n    This debug instrument will dump the arguments and output of each\n    VM Call instruction into a .npz file. It will also alert the user\n    if any function outputs are NaN or INF.\n    \"\"\"\n\n    def __init__(self, debug_out: Path):\n        \"\"\"Constructor\n\n        Parameters\n        ----------\n        debug_out : Path\n            the directory to dump the .npz files\n        \"\"\"\n        self.counter = 0\n        self.first_nan_occurred = False\n        self.first_inf_occurred = False\n        self.debug_out = debug_out\n        debug_out.mkdir(exist_ok=True, parents=True)\n\n    def reset(self, debug_out: Path):\n        \"\"\"Reset the state of the Instrument class\n\n        Parameters\n        ----------\n        debug_out : Path\n            the directory to dump the .npz files\n        \"\"\"\n        self.counter = 0\n        self.first_nan_occurred = False\n        self.first_inf_occurred = False\n        self.debug_out = debug_out\n        debug_out.mkdir(exist_ok=True, parents=True)\n\n    def __call__(self, func, name, before_run, ret_val, *args):\n        # Determine what functions to look at\n        if before_run:  # Whether before the function is called or after\n            return\n        if self.first_nan_occurred:\n            return\n        if self.first_inf_occurred:\n            return\n        if (\n            name.startswith(\"vm.builtin.\")\n            and \"call_tir_dyn\" not in name\n            and \"attention_with_fused_qkv\" not in name\n            and \"self_attention\" not in name\n            and \"cross_attention\" not in name\n        ):\n            return\n\n        # Decide what to print or save about the function's arguments (where args[-1] is the\n        # buffer we write the result to)\n        func_name = f\"f{self.counter}_{name}\"\n\n        # Write your own behavior below. 
For example, we can count the number of INF/NaN in args[-1]\n        def _check_nan_inf(npy):\n            num_nans = np.sum(np.isnan(npy))\n            num_infs = np.sum(np.isinf(npy))\n            if num_nans > 0:\n                print(f\"{red(f'{func_name} has NaN')}: {num_nans}\")\n                self.first_nan_occurred = True\n            if num_infs > 0:\n                print(f\"{red(f'{func_name} has INF')}: {num_infs}\")\n                self.first_inf_occurred = True\n\n        # Save the arguments to npz\n        arg_dict = {}\n        for i, arg in enumerate(args):\n            if isinstance(arg, tvm.runtime.Tensor):\n                if np.prod(arg.shape) * (DataType(arg.dtype).bits // 8) > 2147483648:\n                    # Skip dumping tensors larger than 2 GiB; store a scalar placeholder instead\n                    arg_dict[f\"arg_{i}\"] = np.zeros(())\n                elif arg.dtype in [\"bfloat16\", \"float8_e4m3fn\"]:\n                    arg_dict[f\"arg_{i}\"] = arg.numpy().astype(np.float32)\n                else:\n                    arg_dict[f\"arg_{i}\"] = arg.numpy()\n                _check_nan_inf(arg.numpy())\n        np.savez(self.debug_out / f\"{func_name}.npz\", **arg_dict)\n\n        self.counter += 1\n\n\nclass DebugChat:  # pylint: disable=too-many-instance-attributes, too-few-public-methods\n    \"\"\"A chat interface used only for debugging purposes.\n\n    It debugs auto-regressive decoding fully in Python via the prefill and\n    decode interface. 
It supports a debugging instrument (either the default one or a\n    customized one) to dump intermediate values for each VM function call.\n\n    Given a prompt, it also prints out the parsed prompt, input tokens, output\n    tokens and output text.\n\n    Sample usage:\n\n    dc = DebugChat(\n        model=\"./dist/Llama-2-7b-chat-hf-q4f16_1-MLC\",\n        debug_dir=Path(\"./debug-llama-2\"),\n        model_lib=\"./dist/llama-2-7b-chat-q4f16_1-metal.so\",\n    )\n    dc.generate(\"hello world\", 3)\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments\n        self,\n        model: str,\n        model_lib: str,\n        debug_dir: Path,\n        device: Optional[str] = \"auto\",\n        debug_instrument: Optional[Any] = None,\n        is_image_model: Optional[bool] = False,\n        disable_instrument: Optional[bool] = False,\n    ):\n        \"\"\"Load the model library, weights, chat config and tokenizer for debugging.\n\n        Parameters\n        ----------\n        model: str\n            The path to the compiled MLC model folder, i.e. the directory that contains\n            ``mlc-chat-config.json``, the tokenizer files and the converted weights.\n\n        model_lib : str\n            The full path to the model library file to use (e.g. a ``.so`` file).\n\n        debug_dir: Path\n            The output folder to store the dumped debug files.\n\n        device : Optional[str]\n            The description of the device to run on. The user should provide a string in the\n            form of 'device_name:device_id' or 'device_name', where 'device_name' is one of\n            'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the\n            local device), and 'device_id' is the device id to run on. 
If no 'device_id'\n            is provided, it will be set to 0 by default.\n\n        debug_instrument : Optional[Any]\n            An instrument function that will be called before/after each Call instruction.\n            The function has the following signature:\n\n            .. code:: python\n\n                def instrument(\n                    func: Union[VMClosure, Function],\n                    func_symbol: str,\n                    before_run: bool,\n                    ret_value: any,\n                    *args) -> bool:\n                    pass\n\n            The instrument takes the following parameters:\n            - func: function object to be called.\n            - func_symbol: the symbol name of the function.\n            - before_run: whether it is before or after call.\n            - ret_value: the return value of the call, only valid after run.\n            - args: the arguments being passed to call.\n\n            The instrument object must also provide a ``debug_out`` attribute and a\n            ``reset(debug_out)`` method, which ``generate`` uses between steps. If None,\n            a ``DefaultDebugInstrument`` that dumps to ``debug_dir / \"prefill\"`` is used.\n\n        is_image_model: Optional[bool]\n            Whether the model supports image input. If so, the ``image_embed`` function is\n            looked up in the model library. Defaults to False.\n\n        disable_instrument: Optional[bool]\n            If True, no debug instrument is used, which makes generation faster. 
Defaults to False.\n        \"\"\"\n        self.debug_dir = debug_dir\n        self.device = detect_device(device)\n        if disable_instrument:\n            self.instrument = None\n        else:\n            self.instrument = (\n                debug_instrument\n                if debug_instrument\n                else DefaultDebugInstrument(debug_dir / \"prefill\")\n            )\n        self.mod, self.params, self.metadata = _get_tvm_module(\n            model, model_lib, self.device, self.instrument\n        )\n        self.model_path = Path(model)\n        self.config_file_path = self.model_path / \"mlc-chat-config.json\"\n        with open(self.config_file_path, mode=\"rt\", encoding=\"utf-8\") as file:\n            self.chat_config = MLCChatConfig.model_validate_json(file.read())\n\n        conv_template = self.chat_config.conv_template\n\n        self.conversation = (\n            ConvTemplateRegistry.get_conv_template(conv_template)\n            if isinstance(conv_template, str)\n            else conv_template\n        )\n        self.tokenizer = Tokenizer(str(self.model_path))\n\n        self.add_sequence_func = tvm.get_global_func(\"vm.builtin.kv_state_add_sequence\")\n        self.begin_forward_func = tvm.get_global_func(\"vm.builtin.kv_state_begin_forward\")\n        self.end_forward_func = tvm.get_global_func(\"vm.builtin.kv_state_end_forward\")\n        self.nd_view_func = tvm.get_global_func(\"vm.builtin.reshape\")\n        self.sample_topp_from_prob_func = tvm.get_global_func(\"vm.builtin.sample_top_p_from_prob\")\n\n        try:\n            self.embed_func = self.mod[\"embed\"]\n        except AttributeError as exc:\n            raise RuntimeError(\"DebugChat only supports separate embedding layer\") from exc\n\n        if is_image_model:\n            try:\n                self.embed_image_func = self.mod[\"image_embed\"]\n            except AttributeError as exc:\n                raise RuntimeError(\n                    \"Expect the model to 
be an image model, but cannot find `image_embed`.\"\n                ) from exc\n\n        self.prefill_func = self.mod[\"prefill\"]\n        self.decode_func = self.mod[\"decode\"]\n        self.create_kv_cache_func = None\n        if self.mod.implements_function(\"create_flashinfer_paged_kv_cache\"):\n            self.create_kv_cache_func = self.mod[\"create_flashinfer_paged_kv_cache\"]\n        elif self.mod.implements_function(\"create_tir_paged_kv_cache\"):\n            self.create_kv_cache_func = self.mod[\"create_tir_paged_kv_cache\"]\n        else:\n            # TODO: Support RNN KVState # pylint: disable=fixme\n            raise RuntimeError(\"DebugChat cannot find create KV cache function\")\n\n        self.appeared_token_freq: Dict[int, int] = {}\n\n    def _preprocess_prompts(\n        self, prompt: str, image_url: Optional[str] = None\n    ) -> List[Union[List[int], data.ImageData]]:\n        print(\"======================= Starts Tokenization & Embedding =======================\")\n        # Step 0. 
Generate prompt string using conversation template\n        if image_url is None:\n            self.conversation.messages.append((\"user\", prompt))\n        else:\n            self.conversation.messages.append(\n                (\n                    \"user\",\n                    [\n                        {\"type\": \"image_url\", \"image_url\": image_url},\n                        {\"type\": \"text\", \"text\": prompt},\n                    ],\n                )\n            )\n        self.conversation.messages.append((\"assistant\", None))\n\n        with open(self.config_file_path, \"r\", encoding=\"utf-8\") as file:\n            config = json.load(file)\n        parsed_prompt = self.conversation.as_prompt(config)\n        print(\n            \"Parsed prompt using conversation template \"\n            f\"{green(self.conversation.name)}: {parsed_prompt}\"\n        )\n        tokens = engine_utils.process_prompts(parsed_prompt, self.tokenizer.encode)  # type: ignore\n\n        if self.conversation.system_prefix_token_ids is not None:\n            tokens[0] = self.conversation.system_prefix_token_ids + tokens[0]\n\n        return tokens\n\n    def _embed(\n        self, data_inputs: List[Union[List[int], data.ImageData]]\n    ) -> Tuple[tvm.runtime.Tensor, int]:\n        # Embed each input separately, concatenate the embeddings in numpy, then convert\n        # back to a tvm tensor. This could be optimized, but suffices for debugging.\n        embeddings = []\n        for data_input in data_inputs:\n            if isinstance(data_input, data.ImageData):\n                # Process image data\n                # print(f\"data_input.get_embed_size(): {data_input.embed_size}\")\n                image_input = data_input.image\n                if data_input.image.device != self.device:\n                    image_input = data_input.image.copyto(self.device)\n                embeddings.append(self.embed_image_func(image_input, self.params).numpy())\n            
else:\n                # Process token data\n                data_input = tvm.runtime.tensor(\n                    np.array(data_input).astype(\"int32\"), device=self.device\n                )\n                embeddings.append(self.embed_func(data_input, self.params).numpy())\n        # for embedding in embeddings:\n        #     print(f\"embedding.shape: {embedding.shape}\")\n\n        # Concatenate\n        concat_embeddings = tvm.runtime.tensor(\n            np.concatenate(embeddings, axis=0), device=self.device\n        )\n        concat_embeddings = self.nd_view_func(\n            concat_embeddings,\n            ShapeTuple([1, concat_embeddings.shape[0], concat_embeddings.shape[1]]),\n        )\n        input_len = concat_embeddings.shape[1]\n\n        return concat_embeddings, input_len\n\n    def _prefill(self, embedding: tvm.runtime.Tensor, input_len: int):\n        print(\"======================= Starts Prefill =======================\")\n        seq_len_shape = ShapeTuple([input_len])\n        max_num_sequence = 1\n        page_size = 16\n        sliding_window_size = (\n            self.chat_config.sliding_window_size\n            if self.chat_config.sliding_window_size\n            else self.metadata[\"sliding_window_size\"]\n        )\n        context_window_size = (\n            self.chat_config.context_window_size\n            if self.chat_config.context_window_size\n            else self.metadata[\"context_window_size\"]\n        )\n        prefill_chunk_size = (\n            self.chat_config.prefill_chunk_size\n            if self.chat_config.prefill_chunk_size\n            else self.metadata[\"prefill_chunk_size\"]\n        )\n        max_total_sequence_length = (\n            sliding_window_size if context_window_size == -1 else context_window_size\n        )\n        support_sliding_window = int(sliding_window_size != -1)\n\n        kv_caches = self.create_kv_cache_func(\n            ShapeTuple([max_num_sequence]),\n            
ShapeTuple([max_total_sequence_length]),\n            ShapeTuple([prefill_chunk_size]),\n            ShapeTuple([page_size]),\n            ShapeTuple([support_sliding_window]),\n        )\n        self.add_sequence_func(kv_caches, 0)\n        self.begin_forward_func(kv_caches, ShapeTuple([0]), seq_len_shape)\n        logits, kv_caches = self.prefill_func(embedding, kv_caches, self.params)\n        self.end_forward_func(kv_caches)\n        return logits, kv_caches\n\n    def _decode(self, token: int, kv_caches: Object):\n        embedding, _ = self._embed([[token]])\n        self.begin_forward_func(kv_caches, ShapeTuple([0]), ShapeTuple([1]))\n        logits, kv_caches = self.decode_func(embedding, kv_caches, self.params)\n        self.end_forward_func(kv_caches)\n        return logits\n\n    def _softmax_with_temperature(self, logits: np.ndarray, temperature: float):\n        # Adjust logits based on the temperature\n        logits = np.array(logits) / temperature\n        logits -= np.max(logits, axis=-1, keepdims=True)\n\n        exp_logits = np.exp(logits, logits)\n        exp_logits /= np.sum(exp_logits, axis=-1, keepdims=True)\n        return exp_logits\n\n    def _apply_presence_and_freq_penalty(\n        self, logits: np.ndarray, presence_penalty: float, freq_penalty: float\n    ):\n        for token_id, freq in self.appeared_token_freq.items():\n            logits[:, :, token_id] -= freq * freq_penalty + presence_penalty\n\n    def _sample_token_from_logits(\n        self,\n        logits: tvm.runtime.Tensor,\n        *,\n        temperature=1.0,\n        top_p=1.0,\n        presence_penalty=0.0,\n        frequency_penalty=0.0,\n    ):\n        logits_np = logits.numpy()\n\n        if presence_penalty != 0.0 or frequency_penalty != 0.0:\n            self._apply_presence_and_freq_penalty(logits_np, presence_penalty, frequency_penalty)\n\n        logits_np = self._softmax_with_temperature(logits_np, temperature)\n        if self.instrument is not None:\n      
      np.savez(self.instrument.debug_out / \"logits.npz\", logits_np)\n\n        logits = logits.copyfrom(logits_np)\n        next_token = self.sample_topp_from_prob_func(logits, top_p, random.random())\n        return next_token\n\n    def generate(\n        self,\n        prompt: str,\n        generate_length: int,\n        image_url: Optional[str] = None,\n    ):\n        \"\"\"Generate the response from the model given a user prompt. The user needs to\n        specify the generation length for debugging purposes. For example, a generation\n        length of 3 includes 1 prefill step and 2 decode steps.\n\n        Parameters\n        ----------\n        prompt : str\n            The user input prompt.\n\n        generate_length : int\n            The maximum number of tokens to generate. Decoding stops early when a stop\n            token of the conversation template is generated.\n\n        image_url : Optional[str]\n            The image to prefill together with the prompt. Only valid for image models.\n        \"\"\"\n        out_tokens = []\n\n        data_inputs = self._preprocess_prompts(prompt, image_url)\n        print(f\"{green('Data inputs')}: {data_inputs}\")\n        embedding, input_len = self._embed(data_inputs)\n        logits, kv_caches = self._prefill(embedding, input_len)\n        next_token = self._sample_token_from_logits(logits)\n        out_tokens.append(next_token)\n        if self.instrument is not None:\n            path_str = (self.debug_dir / \"prefill\").as_posix()\n            print(f\"Debug instrument output dumped to {green(path_str)}\")\n\n        print(\"======================= Starts Decode =======================\")\n        for i in range(generate_length - 1):\n            if self.instrument is not None:\n                self.instrument.reset(self.debug_dir / f\"decode_{i}\")\n            logits = self._decode(next_token, kv_caches)\n            next_token = self._sample_token_from_logits(logits)\n            out_tokens.append(next_token)\n            if self.instrument is not None:\n                path_str = (self.debug_dir / f\"decode_{i}\").as_posix()\n                print(f\"Debug instrument output dumped to {green(path_str)}\")\n\n            
if next_token in self.conversation.stop_token_ids:\n                break\n\n        print(f\"{green('Generated output tokens')}: {np.array(out_tokens)}\")\n\n        out_text = self.tokenizer.decode(out_tokens)\n        print(f\"{green('Generated output text')}: {out_text}\")\n\n\ndef main():\n    \"\"\"The main function to start a DebugChat CLI\"\"\"\n\n    parser = ArgumentParser(\"MLC LLM Chat Debug Tool\")\n    parser.add_argument(\n        \"prompt\",\n        type=str,\n        help=\"The user input prompt.\",\n    )\n    parser.add_argument(\n        \"--generate-len\",\n        type=int,\n        help=\"Number of output tokens to generate.\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        help=\"An MLC model directory that contains `mlc-chat-config.json`\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--model-lib\",\n        type=str,\n        help=\"The full path to the model library file to use (e.g. 
a ``.so`` file).\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--debug-dir\",\n        type=str,\n        help=\"The output folder to store the dumped debug files.\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"device_compile\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--image-url\",\n        type=str,\n        required=False,\n        help=\"Image to prefill into the model, can only be set for image models\",\n    )\n    parser.add_argument(\n        \"--disable-instrument\",\n        action=\"store_true\",\n        help=(\n            \"Disable dumping customizable detailed information of kernel input \"\n            + \"and output, hence making generation faster.\"\n        ),\n    )\n    parsed = parser.parse_args()\n    dc = DebugChat(\n        model=parsed.model,\n        model_lib=parsed.model_lib,\n        debug_dir=Path(parsed.debug_dir),\n        device=parsed.device,\n        is_image_model=parsed.image_url is not None,\n        disable_instrument=parsed.disable_instrument,\n    )\n\n    dc.generate(parsed.prompt, parsed.generate_len, parsed.image_url)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/testing/debug_compare.py",
    "content": "\"\"\"Debug compiled models with TVM instrument\"\"\"\n\nimport os\nfrom pathlib import Path\nfrom typing import Dict, List, Set, Tuple\n\nimport tvm\nfrom tvm import rpc, runtime\nfrom tvm.relax.testing.lib_comparator import LibCompareVMInstrument\n\nfrom mlc_llm.interface.help import HELP\nfrom mlc_llm.support.argparse import ArgumentParser\nfrom mlc_llm.testing.debug_chat import DebugChat\n\n\ndef _print_as_table(sorted_list):\n    print(\"=\" * 100)\n    print(\n        \"Name\".ljust(50)\n        + \"Time (ms)\".ljust(12)\n        + \"Count\".ljust(8)\n        + \"Total time (ms)\".ljust(18)\n        + \"Percentage (%)\"\n    )\n    total_time = sum(record[1][0] * record[1][1] for record in sorted_list) * 1000\n    for record in sorted_list:\n        time = record[1][0] * 1000\n        weighted_time = time * record[1][1]\n        percentage = weighted_time / total_time * 100\n        print(\n            record[0].ljust(50)\n            + f\"{time:.4f}\".ljust(12)\n            + str(record[1][1]).ljust(8)\n            + f\"{weighted_time:.4f}\".ljust(18)\n            + f\"{percentage:.2f}\"\n        )\n    print(f\"Total time: {total_time:.4f} ms\")\n\n\nclass LibCompare(LibCompareVMInstrument):\n    \"\"\"A debug instrument that validates the kernels of a second model library\n    against the model that is being run.\n\n    For each kernel call made by the running model (skipping shape functions and\n    functions that were already visited), the function of the same name in ``mod``\n    is run on ``device`` with copies of the same inputs, and its outputs are\n    compared with those of the running model. 
When ``time_eval`` is set, each validated\n    function is also timed on ``device``.\n\n    Parameters\n    ----------\n    mod: runtime.Module\n        The module of interest to be validated.\n\n    device: runtime.Device\n        The device to run the target module on.\n\n    debug_out: Path\n        The directory for debug output; it is created if it does not exist.\n\n    time_eval: bool\n        Whether to time evaluate the functions.\n\n    rtol: float\n        The relative tolerance used in validation.\n\n    atol: float\n        The absolute tolerance used in validation.\n\n    skip_rounds: int\n        The number of (non-shape) function calls to skip before validation starts.\n    \"\"\"\n\n    def __init__(  # pylint: disable=too-many-arguments, unused-argument\n        self,\n        mod: runtime.Module,\n        device: runtime.Device,\n        debug_out: Path,\n        time_eval: bool = True,\n        rtol: float = 1e-2,\n        atol: float = 1,\n        skip_rounds: int = 0,\n    ):\n        super().__init__(mod, device, True, rtol, atol)\n        self.debug_out = debug_out\n        self.time_eval = time_eval\n        self.time_eval_results: Dict[str, Tuple[float, int]] = {}\n        self.visited: Set[str] = set([])\n        self.skip_rounds = skip_rounds\n        self.counter = 0\n        debug_out.mkdir(exist_ok=True, parents=True)\n\n    def reset(self, debug_out: Path):  # pylint: disable=unused-argument\n        \"\"\"Reset the state of the instrument and print the collected timing results.\n\n        Parameters\n        ----------\n        debug_out : Path\n            The new directory for debug output.\n        \"\"\"\n        self.debug_out = debug_out\n        _print_as_table(\n            sorted(\n                self.time_eval_results.items(),\n                key=lambda x: -(x[1][0] * x[1][1]),\n            )\n        )\n        self.time_eval_results = {}\n        self.visited = set([])\n        self.counter = 0\n        debug_out.mkdir(exist_ok=True, parents=True)\n\n    def skip_instrument(self, func, name, before_run, ret_val, *args):\n        if name.startswith(\"shape_func\"):\n            return True\n        if self.counter < 
self.skip_rounds:\n            self.counter += 1\n            print(f\"[{self.counter}] Skip validating {name}..\")\n            return True\n        if name in self.visited:\n            if self.time_eval and name in self.time_eval_results:\n                record = self.time_eval_results[name]\n                self.time_eval_results[name] = (record[0], record[1] + 1)\n            return True\n        self.visited.add(name)\n        return False\n\n    def compare(\n        self,\n        name: str,\n        ref_args: List[tvm.runtime.Tensor],\n        new_args: List[tvm.runtime.Tensor],\n        ret_indices: List[int],\n    ):\n        super().compare(name, ref_args, new_args, ret_indices)\n\n        if self.time_eval and name not in self.time_eval_results:\n            res = self.mod.time_evaluator(\n                name,\n                self.device,\n                number=20,\n                repeat=3,\n                min_repeat_ms=100,\n                # cache_flush_bytes=256 * 10**6\n            )(*new_args)\n            self.time_eval_results[name] = (res.mean, 1)\n            print(f\"Time-eval result {name} on {self.device}:\\n {res}\")\n\n\ndef get_instrument(args):\n    \"\"\"Get the debug instrument from the CLI arguments\"\"\"\n    if args.cmp_device is None:\n        assert args.cmp_lib_path is None, \"cmp_lib_path must be None if cmp_device is None\"\n        args.cmp_device = args.device\n        args.cmp_lib_path = args.model_lib\n\n    if args.cmp_device == \"iphone\":\n        assert args.cmp_lib_path.endswith(\".dylib\"), \"Require a dylib file for iPhone\"\n        proxy_host = os.environ.get(\"TVM_RPC_PROXY_HOST\", \"127.0.0.1\")\n        proxy_port = int(os.environ.get(\"TVM_RPC_PROXY_PORT\", \"9090\"))\n        sess = rpc.connect(proxy_host, proxy_port, \"iphone\")\n        sess.upload(args.cmp_lib_path)\n        lib = sess.load_module(os.path.basename(args.cmp_lib_path))\n        cmp_device = sess.metal()\n    elif args.cmp_device == 
\"android\":\n        assert args.cmp_lib_path.endswith(\".so\"), \"Require a so file for Android\"\n        tracker_host = os.environ.get(\"TVM_TRACKER_HOST\", \"0.0.0.0\")\n        tracker_port = int(os.environ.get(\"TVM_TRACKER_PORT\", \"9190\"))\n        tracker = rpc.connect_tracker(tracker_host, tracker_port)\n        sess = tracker.request(\"android\")\n        sess.upload(args.cmp_lib_path)\n        lib = sess.load_module(os.path.basename(args.cmp_lib_path))\n        cmp_device = sess.cl(0)\n    else:\n        lib = tvm.runtime.load_module(args.cmp_lib_path)\n        cmp_device = tvm.device(args.cmp_device)\n\n    return LibCompare(\n        lib,\n        cmp_device,\n        time_eval=args.time_eval,\n        debug_out=Path(args.debug_dir),\n    )\n\n\ndef main():\n    \"\"\"The main function to start a DebugChat CLI\"\"\"\n\n    parser = ArgumentParser(\"MLC LLM Chat Debug Tool\")\n    parser.add_argument(\n        \"prompt\",\n        type=str,\n        help=\"The user input prompt.\",\n    )\n    parser.add_argument(\n        \"--generate-len\",\n        type=int,\n        help=\"Number of output tokens to generate.\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--model\",\n        type=str,\n        help=\"An MLC model directory that contains `mlc-chat-config.json`\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--model-lib\",\n        type=str,\n        help=\"The full path to the model library file to use (e.g. 
a ``.so`` file).\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--debug-dir\",\n        type=str,\n        help=\"The output folder to store the dumped debug files.\",\n        required=True,\n    )\n    parser.add_argument(\n        \"--device\",\n        type=str,\n        default=\"auto\",\n        help=HELP[\"device_compile\"] + ' (default: \"%(default)s\")',\n    )\n    parser.add_argument(\n        \"--cmp-device\",\n        type=str,\n        default=None,\n        help=(\n            \"The device to run --cmp-lib-path on, e.g. 'cuda', 'iphone' or 'android' \"\n            \"(default: the value of --device).\"\n        ),\n    )\n    parser.add_argument(\n        \"--cmp-lib-path\",\n        type=str,\n        default=None,\n        help=(\n            \"The model library to validate against --model-lib (default: the value of \"\n            \"--model-lib). Must be left unset when --cmp-device is unset.\"\n        ),\n    )\n    parser.add_argument(\n        \"--time-eval\",\n        action=\"store_true\",\n        help=\"Whether to time evaluate the functions.\",\n    )\n    parsed = parser.parse_args()\n    instrument = get_instrument(parsed)\n    debug_chat = DebugChat(\n        model=parsed.model,\n        model_lib=parsed.model_lib,\n        debug_dir=Path(parsed.debug_dir),\n        device=parsed.device,\n        debug_instrument=instrument,\n    )\n    debug_chat.generate(parsed.prompt, parsed.generate_len)\n    # Only print decode for now\n    _print_as_table(\n        sorted(\n            instrument.time_eval_results.items(),\n            key=lambda x: -(x[1][0] * x[1][1]),\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "python/mlc_llm/testing/pytest_utils.py",
    "content": "\"\"\"Extra utilities to mark tests\"\"\"\n\nimport functools\nimport inspect\nfrom pathlib import Path\nfrom typing import Callable\n\nimport pytest\n\nfrom mlc_llm.support.constants import MLC_TEST_MODEL_PATH\n\n\ndef require_test_model(*models: str):\n    \"\"\"Testcase decorator to require a model\n\n    Examples\n    --------\n    .. code::\n\n        @require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\n        def test_reload_reset_unload(model):\n            # model now points to the right path\n            # specified by MLC_TEST_MODEL_PATH\n            engine = mlc_llm.MLCEngine(model)\n            # test code follows\n\n    Parameters\n    ----------\n    models : List[str]\n        The model directories or URLs.\n    \"\"\"\n    model_paths = []\n    missing_models = []\n\n    for model in models:\n        model_path = None\n        for base_path in MLC_TEST_MODEL_PATH:\n            if (base_path / model / \"mlc-chat-config.json\").is_file():\n                model_path = base_path / model\n                break\n        if model_path is None and (Path(model) / \"mlc-chat-config.json\").is_file():\n            model_path = Path(model)\n\n        if model_path is None:\n            missing_models.append(model)\n        else:\n            model_paths.append(str(model_path))\n\n    message = (\n        f\"Model {', '.join(missing_models)} not found in candidate paths \"\n        f\"{[str(p) for p in MLC_TEST_MODEL_PATH]},\"\n        \" if you set MLC_TEST_MODEL_PATH, please ensure model paths are in the right location,\"\n        \" by default we reuse cache, try to run mlc_llm chat to download right set of models.\"\n    )\n\n    def _decorator(func: Callable[..., None]):\n        wrapped = functools.partial(func, *model_paths)\n        wrapped.__name__ = func.__name__  # type: ignore\n\n        if inspect.iscoroutinefunction(wrapped):\n            # The function is a coroutine function (\"async def func(...)\")\n            
@functools.wraps(wrapped)\n            async def wrapper(*args, **kwargs):\n                if len(missing_models) > 0:\n                    print(f\"{message} skipping...\")\n                    return\n                await wrapped(*args, **kwargs)\n\n        else:\n            # The function is a normal function (\"def func(...)\")\n            @functools.wraps(wrapped)\n            def wrapper(*args, **kwargs):\n                if len(missing_models) > 0:\n                    print(f\"{message} skipping...\")\n                    return\n                wrapped(*args, **kwargs)\n\n        return pytest.mark.skipif(len(missing_models) > 0, reason=message)(wrapper)\n\n    return _decorator\n\n\ndef require_test_tokenizers(*models: str):\n    \"\"\"Testcase decorator to require a path to tokenizers\"\"\"\n    # redirect to require models for now\n    return require_test_model(*models)\n"
  },
  {
    "path": "python/mlc_llm/tokenizers/__init__.py",
    "content": "\"\"\"Namespace for tokenizer related utilities\"\"\"\n\nfrom .streamer import StopStrHandler, TextStreamer\nfrom .tokenizers import Tokenizer\n"
  },
  {
    "path": "python/mlc_llm/tokenizers/_ffi_api.py",
    "content": "\"\"\"FFI APIs for mlc_llm\"\"\"\n\nimport tvm_ffi\n\n# Exports functions registered via TVM_FFI_REGISTER_GLOBAL with the \"mlc\" prefix.\n# e.g. TVM_FFI_REGISTER_GLOBAL(\"mlc.Tokenizer\")\ntvm_ffi.init_ffi_api(\"mlc.tokenizers\", __name__)  # pylint: disable=protected-access\n"
  },
  {
    "path": "python/mlc_llm/tokenizers/streamer.py",
    "content": "\"\"\"Streamers in MLC LLM.\"\"\"\n\nfrom typing import List, Union\n\nimport tvm_ffi\nfrom tvm.runtime import Object, ShapeTuple\n\nfrom . import _ffi_api\nfrom .tokenizers import Tokenizer\n\n\n@tvm_ffi.register_object(\"mlc.TextStreamer\")  # pylint: disable=protected-access\nclass TextStreamer(Object):\n    \"\"\"The class that streams back validated UTF-8 text strings\n    generated by the tokenizer.\n    \"\"\"\n\n    def __init__(self, tokenizer: Tokenizer) -> None:  # pylint: disable=super-init-not-called\n        \"\"\"Create the text streamer from a tokenizer.\"\"\"\n        self.__init_handle_by_constructor__(\n            _ffi_api.TextStreamer,  # type: ignore[attr-defined]  # pylint: disable=no-member\n            tokenizer,  # type: ignore\n        )\n\n    def put(self, delta_tokens: Union[List[int], ShapeTuple]) -> str:\n        \"\"\"Put new delta tokens into the streamer, and get the UTF-8-valid\n        delta string. The text streamer may hold some of the input delta tokens\n        which cannot yet be decoded into valid UTF-8 strings. 
The returned string\n        is always guaranteed to be UTF-8 valid.\n\n        Parameters\n        ----------\n        delta_tokens : Union[List[int], ShapeTuple]\n            The new tokens to put into the streamer.\n\n        Returns\n        -------\n        delta_text : str\n            The decoded delta string after putting the input new tokens.\n        \"\"\"\n        if isinstance(delta_tokens, list):\n            delta_tokens = ShapeTuple(delta_tokens)\n        return _ffi_api.TextStreamerPut(  # type: ignore  # pylint: disable=no-member\n            self, delta_tokens\n        )\n\n    def finish(self) -> str:\n        \"\"\"Return the string decoded by remaining tokens.\"\"\"\n        return _ffi_api.TextStreamerFinish(self)  # type: ignore  # pylint: disable=no-member\n\n\n@tvm_ffi.register_object(\"mlc.StopStrHandler\")  # pylint: disable=protected-access\nclass StopStrHandler(Object):\n    \"\"\"The stop string handler in MLC LLM, which takes input delta tokens\n    one at a time, and return the output delta token before stopping due to\n    stop strings.\"\"\"\n\n    def __init__(  # pylint: disable=super-init-not-called\n        self, stop_strs: List[str], tokenizer: Tokenizer\n    ) -> None:\n        self.__init_handle_by_constructor__(\n            _ffi_api.StopStrHandler,  # type: ignore  # pylint: disable=no-member\n            stop_strs,\n            tokenizer,\n        )\n\n    def put(self, token_id: int) -> List[int]:\n        \"\"\"Add new input delta token to the handler, return output\n        delta tokens before stopping. 
The stop string handler may hold\n        some of the input delta tokens that may be part of a stop string.\n        The returned tokens are always guaranteed not to be part of a stop string.\n        \"\"\"\n        return list(\n            _ffi_api.StopStrHandlerPut(self, token_id)  # type: ignore  # pylint: disable=no-member\n        )\n\n    def finish(self) -> List[int]:\n        \"\"\"Finish stop string handling and return the remaining cached token ids.\"\"\"\n        return list(\n            _ffi_api.StopStringHandlerFinish(self)  # type: ignore  # pylint: disable=no-member\n        )\n\n    @property\n    def stop_triggered(self) -> bool:\n        \"\"\"Check whether generation has stopped due to a stop string.\"\"\"\n        return _ffi_api.StopStrHandlerStopTriggered(self)  # type: ignore  # pylint: disable=no-member\n"
  },
  {
    "path": "python/mlc_llm/tokenizers/tokenizers.py",
    "content": "\"\"\"The tokenizer and related tools in MLC LLM.\nThis tokenizer essentially wraps and binds the HuggingFace tokenizers\nlibrary and sentencepiece.\nReference: https://github.com/mlc-ai/tokenizers-cpp\n\"\"\"\n\nimport json\nfrom dataclasses import asdict, dataclass\nfrom typing import List, Literal\n\nimport tvm\nimport tvm_ffi\nfrom tvm.runtime import Object\n\nfrom . import _ffi_api\n\n\n@dataclass\nclass TokenizerInfo:  # pylint: disable=too-many-instance-attributes\n    \"\"\"Tokenizer information used during generation.\n\n    Attributes\n    ----------\n    token_postproc_method : Literal[\"byte_fallback\", \"byte_level\"]\n        The method used to post-process tokens back into their original strings.\n        Possible values (each refers to a kind of tokenizer):\n        - \"byte_fallback\": The same as the byte-fallback BPE tokenizer, including LLaMA-2,\n            Mixtral-7b, etc. E.g. \"▁of\" -> \" of\", \"<0x1B>\" -> \"\\x1b\".\n            This method:\n            1) Transforms tokens like <0x1B> into the raw byte 0x1B (so-called byte-fallback).\n            2) Replaces \\\\u2581 \"▁\" with a space.\n        - \"byte_level\": The same as the byte-level BPE tokenizer, including LLaMA-3, GPT-2,\n            Phi-2, etc. E.g. 
\"Ġin\" -> \" in\", \"ě\" -> \"\\x1b\"\n            This method inverts the bytes-to-unicode transformation used in the encoding process; see\n            https://github.com/huggingface/transformers/blob/87be06ca77166e6a6215eee5a990ab9f07238a18/src/transformers/models/gpt2/tokenization_gpt2.py#L38-L59\n\n    prepend_space_in_encode : bool\n        Whether to prepend a space during encoding.\n\n    strip_space_in_decode : bool\n        Whether to strip the first space during decoding.\n    \"\"\"\n\n    token_postproc_method: Literal[\"byte_fallback\", \"byte_level\"] = \"byte_fallback\"\n    prepend_space_in_encode: bool = False\n    strip_space_in_decode: bool = False\n\n    def asjson(self) -> str:\n        \"\"\"Return the config as a JSON string.\"\"\"\n        return json.dumps(asdict(self))\n\n    @staticmethod\n    def from_json(json_str: str) -> \"TokenizerInfo\":\n        \"\"\"Construct a config from a JSON string.\"\"\"\n        return TokenizerInfo(**json.loads(json_str))\n\n\n@tvm_ffi.register_object(\"mlc.Tokenizer\")  # pylint: disable=protected-access\nclass Tokenizer(Object):\n    \"\"\"The tokenizer class in MLC LLM.\"\"\"\n\n    def __init__(self, tokenizer_path: str) -> None:  # pylint: disable=super-init-not-called\n        \"\"\"Create the tokenizer from a tokenizer directory path.\"\"\"\n        self.__init_handle_by_constructor__(\n            _ffi_api.Tokenizer,  # type: ignore[attr-defined]  # pylint: disable=no-member\n            tokenizer_path,  # type: ignore\n        )\n\n    def encode(self, text: str) -> List[int]:\n        \"\"\"Encode text into ids.\n\n        Parameters\n        ----------\n        text : str\n            The text string to encode.\n\n        Returns\n        -------\n        token_ids : List[int]\n            The list of encoded token ids.\n        \"\"\"\n        return list(_ffi_api.TokenizerEncode(self, text))  # type: ignore  # pylint: disable=no-member\n\n    def encode_batch(self, texts: List[str]) -> 
List[List[int]]:\n        \"\"\"Encode a batch of texts into ids.\n\n        Parameters\n        ----------\n        texts : List[str]\n            The list of text strings to encode.\n\n        Returns\n        -------\n        token_ids : List[List[int]]\n            The encoded token ids, one list per input text.\n        \"\"\"\n        return list(_ffi_api.TokenizerEncodeBatch(self, texts))  # type: ignore  # pylint: disable=no-member\n\n    def decode(self, token_ids: List[int]) -> str:\n        \"\"\"Decode token ids into text.\n\n        Parameters\n        ----------\n        token_ids : List[int]\n            The token ids to decode.\n\n        Returns\n        -------\n        text : str\n            The decoded text string.\n        \"\"\"\n        return _ffi_api.TokenizerDecode(  # type: ignore  # pylint: disable=no-member\n            self, tvm.runtime.ShapeTuple(token_ids)\n        )\n\n    @staticmethod\n    def detect_tokenizer_info(tokenizer_path: str) -> TokenizerInfo:\n        \"\"\"Detect the tokenizer info from the given tokenizer path.\n\n        Parameters\n        ----------\n        tokenizer_path : str\n            The tokenizer directory path.\n\n        Returns\n        -------\n        tokenizer_info : TokenizerInfo\n            The detected tokenizer info.\n        \"\"\"\n        return TokenizerInfo.from_json(_ffi_api.DetectTokenizerInfo(tokenizer_path))  # type: ignore  # pylint: disable=no-member\n"
  },
  {
    "path": "python/requirements.txt",
    "content": "apache-tvm-ffi\ndatasets\nfastapi\nflashinfer-python\nml_dtypes>=0.5.1\nopenai\npandas\nprompt_toolkit\nrequests\nsafetensors\nsentencepiece\nshortuuid\ntiktoken\ntorch\ntqdm\ntransformers\nuvicorn\n"
  },
  {
    "path": "python/setup.py",
    "content": "# pylint: disable=invalid-name, exec-used\n\"\"\"Setup MLC LLM package.\"\"\"\n\nimport os\nimport shutil\n\nfrom setuptools import find_packages, setup\nfrom setuptools.dist import Distribution\n\nCURRENT_DIR = os.path.dirname(__file__)\nCONDA_BUILD = os.getenv(\"CONDA_BUILD\") is not None\n\n\ndef get_lib_path():\n    \"\"\"Get library path, name and version\"\"\"\n    # Directly exec libinfo to get the right setup\n    libinfo_py = os.path.join(CURRENT_DIR, \"./mlc_llm/libinfo.py\")\n    libinfo = {\"__file__\": libinfo_py}\n    with open(libinfo_py, \"rb\") as f:\n        exec(compile(f.read(), libinfo_py, \"exec\"), libinfo, libinfo)\n    version = libinfo[\"__version__\"]\n\n    # conda installs libraries into env instead of packaging with pip\n    if not CONDA_BUILD:\n        libs = [\n            libinfo[\"find_lib_path\"](\"mlc_llm\")[0],\n            libinfo[\"find_lib_path\"](\"mlc_llm_module\")[0],\n        ]\n    else:\n        libs = None\n\n    return libs, version\n\n\ndef git_describe_version(original_version):\n    \"\"\"Get git describe version.\"\"\"\n    ver_py = os.path.join(CURRENT_DIR, \"..\", \"version.py\")\n    libver = {\"__file__\": ver_py}\n    with open(ver_py, \"rb\") as f:\n        exec(compile(f.read(), ver_py, \"exec\"), libver, libver)\n    _, gd_version = libver[\"git_describe_version\"]()\n    if gd_version is not None and gd_version != original_version:\n        print(f\"Use git describe based version {gd_version}\")\n    if gd_version is None:\n        print(f\"Use original version {original_version}\")\n        return original_version\n    return gd_version\n\n\ndef parse_requirements(filename: os.PathLike):\n    \"\"\"Parse requirements.txt.\"\"\"\n    with open(filename, encoding=\"utf-8\") as f:\n        requirements = f.read().splitlines()\n\n        def extract_url(line):\n            return next(filter(lambda x: x[0] != \"-\", line.split()))\n\n        extra_URLs = []\n        deps = []\n        for 
line in requirements:\n            if line.startswith((\"#\", \"-r\")):\n                continue\n\n            # handle -i and --extra-index-url options\n            if \"-i \" in line or \"--extra-index-url\" in line:\n                extra_URLs.append(extract_url(line))\n            else:\n                deps.append(line)\n    return deps, extra_URLs\n\n\nLIB_LIST, __version__ = get_lib_path()\n__version__ = git_describe_version(__version__)\n\n\nclass BinaryDistribution(Distribution):\n    \"\"\"This class is needed in order to create OS-specific wheels.\"\"\"\n\n    def has_ext_modules(self):\n        \"\"\"Return True for binary distribution.\"\"\"\n        return True\n\n    def is_pure(self):\n        \"\"\"Return False for binary distribution.\"\"\"\n        return False\n\n\ndef main():\n    \"\"\"The main entrypoint.\"\"\"\n    setup_kwargs = {}\n    if not CONDA_BUILD:\n        with open(\"MANIFEST.in\", \"w\", encoding=\"utf-8\") as fo:\n            for path in LIB_LIST:\n                if os.path.isfile(path):\n                    shutil.copy(path, os.path.join(CURRENT_DIR, \"mlc_llm\"))\n                    _, libname = os.path.split(path)\n                    fo.write(f\"include mlc_llm/{libname}\\n\")\n        setup_kwargs = {\"include_package_data\": True}\n\n    setup(\n        name=\"mlc_llm\",\n        version=__version__,\n        description=\"MLC LLM: a universal LLM deployment engine via ML compilation.\",\n        url=\"https://llm.mlc.ai/\",\n        author=\"MLC LLM Contributors\",\n        license=\"Apache 2.0\",\n        # See https://pypi.org/classifiers/\n        classifiers=[\n            \"License :: OSI Approved :: Apache Software License\",\n            \"Development Status :: 4 - Beta\",\n            \"Intended Audience :: Developers\",\n            \"Intended Audience :: Education\",\n            \"Intended Audience :: Science/Research\",\n        ],\n        keywords=\"machine learning\",\n        zip_safe=False,\n        
packages=find_packages(),\n        entry_points={\n            \"console_scripts\": [\"mlc_llm = mlc_llm.__main__:main\"],\n        },\n        package_dir={\"mlc_llm\": \"mlc_llm\"},\n        install_requires=parse_requirements(\"requirements.txt\")[0],\n        distclass=BinaryDistribution,\n        **setup_kwargs,\n    )\n\n    def _remove_path(path):\n        if os.path.exists(path):\n            if os.path.isfile(path):\n                os.remove(path)\n            elif os.path.isdir(path):\n                shutil.rmtree(path)\n\n    if not CONDA_BUILD:\n        # Wheel cleanup\n        os.remove(\"MANIFEST.in\")\n        for path in LIB_LIST:\n            _, libname = os.path.split(path)\n            _remove_path(f\"mlc_llm/{libname}\")\n\n\nmain()\n"
  },
  {
    "path": "scripts/build_mlc_for_docs.sh",
    "content": "#!/bin/bash\nset -euxo pipefail\n\nmkdir -p build\ncd build\ncmake .. -DCMAKE_POLICY_VERSION_MINIMUM=3.5\nmake -j$(nproc)\ncd -\n"
  },
  {
    "path": "scripts/build_site.sh",
    "content": "#!/bin/bash\nset -euxo pipefail\n\nexport PYTHONPATH=$PWD/python\ncd docs && make html && cd ..\n\ncd site && jekyll b && cd ..\n\nrm -rf site/_site/docs\ncp -r docs/_build/html site/_site/docs\n"
  },
  {
    "path": "scripts/check_url_validity.py",
    "content": "import argparse\nimport re\nfrom pathlib import Path\n\nimport requests\n\n\ndef find_urls_in_file(file_path):\n    with open(file_path, \"r\") as file:\n        content = file.read()\n\n    # Regular expression pattern to match URLs\n    url_pattern = re.compile(\n        r\"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\\\(\\\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+\"\n    )\n\n    # Find all matches of URLs in the content\n    urls = re.findall(url_pattern, content)\n    return [url.strip(\">\") for url in urls]\n\n\ndef main():\n    parser = argparse.ArgumentParser(description=\"Check validity of links in documentation\")\n    parser.add_argument(\"--directory\", type=str, default=\"docs\", help=\"Directory of documentation.\")\n    args = parser.parse_args()\n\n    # Traverse the directory and find all .rst files\n    doc_directory = Path(args.directory)\n    for file_path in doc_directory.glob(\"**/*.rst\"):\n        print(\"Checking {}...\".format(file_path))\n        for url in find_urls_in_file(file_path):\n            try:\n                r = requests.get(url)\n                if r.status_code == 404:\n                    print(\"404 not found: {}\".format(url))\n            except Exception as e:\n                print(\"Error connecting {}, error: {}\".format(url, e))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/gh_deploy_site.sh",
    "content": "#!/bin/bash\n# NOTE: this script is triggered automatically by a GitHub Action\n# when changes are merged into main\n\nset -euxo pipefail\n\nscripts/build_mlc_for_docs.sh\nscripts/build_site.sh\n\ngit fetch\ngit checkout -B gh-pages origin/gh-pages\nrm -rf docs .gitignore\nmkdir -p docs\ncp -rf site/_site/* docs\ntouch docs/.nojekyll\n\nDATE=`date`\ngit add docs && git commit -am \"Build at ${DATE}\"\ngit push origin gh-pages\ngit checkout main && git submodule update\necho \"Finish deployment at ${DATE}\"\n"
  },
  {
    "path": "scripts/local_deploy_site.sh",
    "content": "#!/bin/bash\n# NOTE: use this script to check local site\n\nset -euxo pipefail\n\nscripts/build_site.sh\n\ncd site && jekyll serve  --skip-initial-build --host localhost --baseurl / --port 8888\n"
  },
  {
    "path": "site/.gitignore",
    "content": "dist\nllm-chat-config.json\n_includes/stable_diffusion.html\n_site\n.jekyll-cache\n"
  },
  {
    "path": "site/CNAME",
    "content": "llm.mlc.ai\n"
  },
  {
    "path": "site/Gemfile",
    "content": "# frozen_string_literal: true\n\nsource \"https://rubygems.org\"\n\n# gem \"rails\"\ngem \"jekyll-remote-theme\"\ngem \"jekyll-sass-converter\"\n"
  },
  {
    "path": "site/_config.yml",
    "content": "name: \"MLC LLM\"\nshort_name: \"MLC LLM\"\n\nurl: https://llm.mlc.ai/\n\nexclude: [README.md, serve_local.sh]\n\nplugins:\n  - jekyll-remote-theme\n\nremote_theme: mlc-ai/jekyll-theme-mlc\n\n\n# Colorize code snippets with the Rouge highlighter if we want to deploy on GitHub Pages.\nhighlighter: rouge\n\nmarkdown: kramdown\n\n# The path structure for blog posts.\npermalink: /blog/:year/:month/:day/:title.html\n\n# Number of news stories on the front page.\nfront_page_news: 8\n\n# Base pathname for links.\nbase: ''\n\n# Make pages for the _projects folder\ncollections:\n  projects:\n    output: true\n\ncourse_title:\n\n# Navigation bar links.\nnavigation:\n  - title: Home\n    link: /\n  - title: Docs\n    link: /docs\n  - title: GitHub\n    link: https://github.com/mlc-ai/mlc-llm\n"
  },
  {
    "path": "site/_includes/head.html",
    "content": "<meta name=\"description\" content=\"WebLLM: High-Performance In-Browser LLM Inference Engine\">\n<meta\n  http-equiv=\"origin-trial\"\n  content=\"Agx76XA0ITxMPF0Z8rbbcMllwuxsyp9qdtQaXlLqu1JUrdHB6FPonuyIKJ3CsBREUkeioJck4nn3KO0c0kkwqAMAAABJeyJvcmlnaW4iOiJodHRwOi8vbG9jYWxob3N0Ojg4ODgiLCJmZWF0dXJlIjoiV2ViR1BVIiwiZXhwaXJ5IjoxNjkxNzExOTk5fQ==\"\n/>\n<meta\n  http-equiv=\"origin-trial\"\n  content=\"AnmwqQ1dtYDQTYkZ5iMtHdINCaxjE94uWQBKp2yOz1wPTcjSRtOHUGQG+r2BxsEuM0qhxTVnuTjyh31HgTeA8gsAAABZeyJvcmlnaW4iOiJodHRwczovL21sYy5haTo0NDMiLCJmZWF0dXJlIjoiV2ViR1BVIiwiZXhwaXJ5IjoxNjkxNzExOTk5LCJpc1N1YmRvbWFpbiI6dHJ1ZX0=\"\n/>\n<script src=\"https://code.jquery.com/jquery-3.6.3.min.js\" integrity=\"sha256-pvPw+upLPUjgMXY0G+8O0xUf+/Im1MZjXxxgOcBQBXU=\" crossorigin=\"anonymous\"></script>\n<link rel=\"stylesheet\" href=\"{{ '/assets/css/hero.css' | relative_url }}\" />\n"
  },
  {
    "path": "site/_includes/hero.html",
    "content": "<section id=\"hero\">\n  <div class=\"heading-container\">\n    <h1>MLC LLM: Universal LLM Deployment Engine With ML Compilation</h1>\n    <div class=\"link-container\">\n      <a class=\"github-link\" href=\"https://github.com/mlc-ai/mlc-llm\">\n        <span class=\"github-link-content\">\n          <span class=\"icon\">{% include github.svg %}</span>\n          <span>GitHub</span>\n          <span class=\"arrow-container\">{% include arrow.svg %}</span>\n          </span>\n      </a>\n      <a class=\"get-start-link moving-border\" href=\"https://llm.mlc.ai/docs/get_started/quick_start\">\n        <span class=\"border\"></span>\n        <span class=\"get-start-link-content\">\n          <span>Get Started</span>\n          <span class=\"arrow-container\">{% include arrow.svg %}</span>\n          </span>\n      </a>\n    </div>\n  </div>\n  <div class=\"demo-container\">\n    <!-- <img class=\"android\" src=\"/assets/gif/android-demo.gif\" alt=\"Android Demo\" width=\"612\" height=\"1334\" />\n    <img class=\"linux\" src=\"/assets/gif/linux-demo.gif\" alt=\"Linux Demo\" width=\"1089\" height=\"667\" />\n    <img class=\"ios\" src=\"/assets/gif/ios-demo.gif\" alt=\"iOS Demo\" width=\"640\" height=\"1394\" /> -->\n    <!-- <img src=\"https://llm.mlc.ai/docs/_images/project-workflow.svg\" alt=\"MLC LLM Architecture\" /> -->\n    {% include project-workflow.svg %}\n  </div>\n</section>\n\n<script>\n  (function() {\n\n  function handlerIn(e) {\n    $(this).addClass(\"expanded\");\n  }\n  function handlerOut(e) {\n    $(this).removeClass(\"expanded\");\n  }\n\n  $(\".chat-link\").hover(handlerIn, handlerOut);\n  $(\".github-link\").hover(handlerIn, handlerOut);\n})()\n</script>\n"
  },
  {
    "path": "site/assets/css/hero.scss",
    "content": "---\n---\n\n#hero {\n    background: radial-gradient(100% 50rem at center 50rem, #3351cb50, #ffffff);\n    padding: 3rem;\n    width: 100vw;\n    margin-left: calc(50% - 50vw);\n    margin-top: -20px;\n    display: flex;\n    flex-direction: column;\n    align-items: center;\n\n    a {\n        color: black;\n    }\n\n    .heading-container {\n        display: flex;\n        flex-direction: column;\n        align-items: center;\n        font-family: \"Mona Sans\", \"MonaSansFallback\", -apple-system, BlinkMacSystemFont, \"Segoe UI\", Helvetica, Arial, sans-serif, \"Apple Color Emoji\", \"Segoe UI Emoji\";\n        margin: auto;\n\n        a {\n            min-width: fit-content;\n            max-width: 16rem;\n            flex-grow: 1;\n        }\n\n        h1 {\n            text-align: center;\n            font-size: 2rem;\n            font-weight: 700;\n        }\n\n        .link-container {\n            display: flex;\n            margin-top: 2rem;\n            align-items: center;\n            flex-wrap: wrap;\n            font-size: 1rem;\n            word-break: keep-all;\n            font-weight: 600;\n            gap: 1rem;\n            justify-content: center;\n\n            .github-link {\n                display: inline-flex;\n                gap: 1rem;\n                border-radius: 9999px;\n                vertical-align: middle;\n                align-items: center;\n                justify-content: center;\n                text-decoration: none;\n                cursor: pointer;\n                height: fit-content;\n                // padding: .25rem;\n\n                .github-link-content {\n                    width: 100%;\n                    height: 100%;\n                    z-index: 1;\n                    border-radius: 9999px;\n                    padding: 1rem 1.75rem;\n                    background-color: #000000;\n                    display: inline-flex;\n                    gap: .5rem;\n                    display: 
inline-flex;\n                    justify-content: center;\n                    color: rgb(229 229 229);\n\n                    .icon {\n                        display: inline-flex;\n                        align-items: center;\n                        margin-right: .5rem;\n\n                        svg {\n                            height: 1.5rem;\n                        }\n                    }\n                }\n            }\n\n            .get-start-link {\n                display: inline-flex;\n                gap: 1rem;\n                background-color: white;\n                border-radius: 9999px;\n                vertical-align: middle;\n                align-items: center;\n                justify-content: center;\n                text-decoration: none;\n                cursor: pointer;\n                height: fit-content;\n                padding: .25rem;\n\n                .get-start-link-content {\n                    width: 100%;\n                    height: 100%;\n                    z-index: 1;\n                    border-radius: 9999px;\n                    padding: 1rem 1.75rem;\n                    background-color: white;\n                    display: inline-flex;\n                    justify-content: center;\n                }\n            }\n\n            .arrow-container {\n                margin-left: .25rem;\n                display: inline-flex;\n                align-items: center;\n            }\n        }\n    }\n\n    .arrow-expandable {\n        stroke-dasharray: 10;\n        stroke-dashoffset: 10;\n        transition: stroke-dashoffset 200ms;\n    }\n\n    .expanded {\n        .arrow-expandable {\n            stroke-dashoffset: 20;\n        }\n    }\n\n    .demo-container {\n        position: relative;\n        margin-top: 96px;\n        width: calc(100% + 4rem);\n        max-width: 1024px;\n        flex-shrink: 0;\n        padding: 2rem;\n\n        svg {\n            height: auto;\n            width: 100%;\n            
border-radius: inherit;\n        }\n    }\n}\n\n.moving-border {\n    overflow: hidden;\n    position: relative;\n\n    .border {\n        position: absolute;\n        inset: -1000%;\n        animation: spin 3s linear infinite;\n        border-radius: 1rem;\n        background-image: conic-gradient(from 90deg at 50% 50%, #e2cbff 0, #393bb2 50%, #e2cbff 100%);\n    }\n}\n\n@media screen and (min-width:640px) {\n    #hero {\n        padding: 6rem;\n\n        .heading-container {\n            max-width: 40rem;\n\n            h1 {\n                font-size: 3rem;\n            }\n        }\n\n        .demo-container {\n            width: calc(100% + 10rem);\n        }\n    }\n}\n\n\n@media screen and (min-width:768px) {\n    #hero {\n        .heading-container {\n            max-width: 45rem;\n\n            h1 {\n                font-size: 3.2rem;\n            }\n\n            .link-container {\n                font-size: 1.2rem;\n            }\n        }\n    }\n}\n\n@media screen and (min-width:1024px) {\n    #hero {\n        padding: 8rem;\n\n        .heading-container {\n            max-width: 50rem;\n\n            h1 {\n                font-size: 3.5rem;\n            }\n        }\n\n        .demo-container {\n            width: 100%;\n        }\n    }\n\n}\n\n@media screen and (min-width:1280px) {\n    #hero {\n        .heading-container {\n            max-width: 60rem;\n\n            h1 {\n                font-size: 4rem;\n            }\n        }\n    }\n}\n\n@media screen and (min-width:1760px) {\n    #hero {\n        background: radial-gradient(100% 50rem at center 50rem, #3351cb50, #ffffff);\n\n        gap: 4rem;\n        padding-bottom: 12rem;\n    }\n}\n\n@keyframes spin {\n    100% {\n        transform: rotate(1turn);\n    }\n}\n"
  },
  {
    "path": "site/index.md",
    "content": "---\nlayout: default\ntitle: Home\nnotitle: true\n---\n\n{% include hero.html %}\n\n## Overview\n\nMLC LLM is a machine learning compiler and high-performance deployment engine for large language models. The mission of this project is to enable everyone to develop, optimize, and deploy AI models natively on their own platforms.\n\nMLC LLM compiles and runs code on MLCEngine, a unified high-performance LLM inference engine that works across these platforms. MLCEngine provides an OpenAI-compatible API available through a REST server, Python, JavaScript, iOS, and Android, all backed by the same engine and compiler that we keep improving with the community.\n\n## Get Started\n\nPlease visit our [documentation](https://llm.mlc.ai/docs/) to get started with MLC LLM.\n- [Installation](https://llm.mlc.ai/docs/install/mlc_llm)\n- [Quick start](https://llm.mlc.ai/docs/get_started/quick_start)\n- [Introduction](https://llm.mlc.ai/docs/get_started/introduction)\n\n## Links\n- [MLC LLM GitHub](https://github.com/mlc-ai/mlc-llm)\n- [WebLLM Project](https://webllm.mlc.ai)\n"
  },
  {
    "path": "site/privacy.md",
    "content": "---\nlayout: default\ntitle: Home\nnotitle: true\n---\n\n# MLC Chat App Privacy\n\nMLC Chat runs all generation locally.\nAll data stays on the user's device and is not collected by the app.\n"
  },
  {
    "path": "tests/README.md",
    "content": "# MLC LLM Tests\n\nWe primarily rely on pytest to test our engine.\nMost unit-level functionality in C++ can be exposed via TVM FFI\nand tested from a Python environment.\n\nWe categorize test cases by adding `pytestmark = [pytest.mark.category_name]`.\nCheck out [python/conftest.py](python/conftest.py) for the available categories.\n"
  },
  {
    "path": "tests/cpp/conv_template_unittest.cc",
    "content": "#include \"json_ffi/conv_template.h\"\n\n#include <gtest/gtest.h>\n\nnamespace mlc {\nnamespace llm {\nnamespace json_ffi {\n\nvoid _TestConvTemplateLoadJSONTextContent() {\n  std::string conv_template =\n      \"{\\n\"\n      \"    \\\"name\\\": \\\"test\\\",\\n\"\n      \"    \\\"system_template\\\": \\\"abc{system_message}\\\",\\n\"\n      \"    \\\"system_message\\\": \\\"de\\\",\\n\"\n      \"    \\\"roles\\\": {\\n\"\n      \"      \\\"user\\\": \\\"Instruct\\\",\\n\"\n      \"      \\\"assistant\\\": \\\"Output\\\",\\n\"\n      \"      \\\"tool\\\": \\\"Instruct\\\"\\n\"\n      \"    },\\n\"\n      \"    \\\"role_templates\\\": {\\n\"\n      \"      \\\"user\\\": \\\"{user_message}\\\",\\n\"\n      \"      \\\"assistant\\\": \\\"{assistant_message}\\\",\\n\"\n      \"      \\\"tool\\\": \\\"{tool_message}\\\"\\n\"\n      \"    },\\n\"\n      \"    \\\"messages\\\": [[\\\"Instruct\\\", \\\"Hello\\\"], [\\\"Output\\\", \\\"Hey\\\"]],\\n\"\n      \"    \\\"seps\\\": [\\n\"\n      \"      \\\"\\\\n\\\"\\n\"\n      \"    ],\\n\"\n      \"    \\\"role_content_sep\\\": \\\": \\\",\\n\"\n      \"    \\\"role_empty_sep\\\": \\\":\\\",\\n\"\n      \"    \\\"stop_str\\\": [\\n\"\n      \"      \\\"<|endoftext|>\\\"\\n\"\n      \"    ],\\n\"\n      \"    \\\"add_role_after_system_message\\\": false,\\n\"\n      \"    \\\"stop_token_ids\\\": [\\n\"\n      \"      50256\\n\"\n      \"    ]\"\n      \"}\";\n\n  auto res = Conversation::FromJSON(conv_template).IsOk();\n  ASSERT_TRUE(res);\n  const Conversation& conv = Conversation::FromJSON(conv_template).Unwrap();\n  ASSERT_EQ(conv.name, \"test\");\n  ASSERT_EQ(conv.system_template, \"abc{system_message}\");\n  ASSERT_EQ(conv.system_message, \"de\");\n  ASSERT_EQ(conv.roles.at(\"user\"), \"Instruct\");\n  ASSERT_EQ(conv.roles.at(\"assistant\"), \"Output\");\n  ASSERT_EQ(conv.roles.at(\"tool\"), \"Instruct\");\n  ASSERT_EQ(conv.role_templates.at(\"user\"), \"{user_message}\");\n  
ASSERT_EQ(conv.role_templates.at(\"assistant\"), \"{assistant_message}\");\n  ASSERT_EQ(conv.role_templates.at(\"tool\"), \"{tool_message}\");\n  ASSERT_EQ(conv.messages.at(0).role, \"Instruct\");\n  ASSERT_EQ(conv.messages.at(0).content.Text(), \"Hello\");\n  ASSERT_EQ(conv.messages.at(1).role, \"Output\");\n  ASSERT_EQ(conv.messages.at(1).content.Text(), \"Hey\");\n  ASSERT_EQ(conv.seps.at(0), \"\\n\");\n  ASSERT_EQ(conv.role_content_sep, \": \");\n  ASSERT_EQ(conv.role_empty_sep, \":\");\n  ASSERT_EQ(conv.stop_str.at(0), \"<|endoftext|>\");\n  ASSERT_EQ(conv.add_role_after_system_message, false);\n  ASSERT_EQ(conv.stop_token_ids.at(0), 50256);\n}\n\nvoid _TestConvTemplateLoadJSONPartsContent() {\n  std::string conv_template =\n      \"{\\n\"\n      \"    \\\"name\\\": \\\"test\\\",\\n\"\n      \"    \\\"system_template\\\": \\\"abc{system_message}\\\",\\n\"\n      \"    \\\"system_message\\\": \\\"de\\\",\\n\"\n      \"    \\\"roles\\\": {\\n\"\n      \"      \\\"user\\\": \\\"Instruct\\\",\\n\"\n      \"      \\\"assistant\\\": \\\"Output\\\",\\n\"\n      \"      \\\"tool\\\": \\\"Instruct\\\"\\n\"\n      \"    },\\n\"\n      \"    \\\"role_templates\\\": {\\n\"\n      \"      \\\"user\\\": \\\"{user_message}\\\",\\n\"\n      \"      \\\"assistant\\\": \\\"{assistant_message}\\\",\\n\"\n      \"      \\\"tool\\\": \\\"{tool_message}\\\"\\n\"\n      \"    },\\n\"\n      \"    \\\"messages\\\": [[\\\"Instruct\\\", \"\n      \"    [{\\\"type\\\": \\\"text\\\", \\\"text\\\": \\\"What's in the image?\\\"},\\n\"\n      \"     {\\\"type\\\": \\\"image_url\\\", \\\"image_url\\\": \\\"https://example.com/image.jpg\\\"}]\\n\"\n      \"    ]],\\n\"\n      \"    \\\"seps\\\": [\\n\"\n      \"      \\\"\\\\n\\\"\\n\"\n      \"    ],\\n\"\n      \"    \\\"role_content_sep\\\": \\\": \\\",\\n\"\n      \"    \\\"role_empty_sep\\\": \\\":\\\",\\n\"\n      \"    \\\"stop_str\\\": [\\n\"\n      \"      \\\"<|endoftext|>\\\"\\n\"\n      \"    ],\\n\"\n      \"    
\\\"add_role_after_system_message\\\": false,\\n\"\n      \"    \\\"stop_token_ids\\\": [\\n\"\n      \"      50256\\n\"\n      \"    ]\"\n      \"}\";\n\n  auto res = Conversation::FromJSON(conv_template).IsOk();\n  ASSERT_TRUE(res);\n  const Conversation& conv = Conversation::FromJSON(conv_template).Unwrap();\n  ASSERT_EQ(conv.name, \"test\");\n  ASSERT_EQ(conv.system_template, \"abc{system_message}\");\n  ASSERT_EQ(conv.system_message, \"de\");\n  ASSERT_EQ(conv.roles.at(\"user\"), \"Instruct\");\n  ASSERT_EQ(conv.roles.at(\"assistant\"), \"Output\");\n  ASSERT_EQ(conv.roles.at(\"tool\"), \"Instruct\");\n  ASSERT_EQ(conv.role_templates.at(\"user\"), \"{user_message}\");\n  ASSERT_EQ(conv.role_templates.at(\"assistant\"), \"{assistant_message}\");\n  ASSERT_EQ(conv.role_templates.at(\"tool\"), \"{tool_message}\");\n  ASSERT_EQ(conv.messages.at(0).role, \"Instruct\");\n  ASSERT_EQ(conv.messages.at(0).content.Parts().at(0).at(\"type\"), \"text\");\n  ASSERT_EQ(conv.messages.at(0).content.Parts().at(0).at(\"text\"), \"What's in the image?\");\n  ASSERT_EQ(conv.messages.at(0).content.Parts().at(1).at(\"type\"), \"image_url\");\n  ASSERT_EQ(conv.messages.at(0).content.Parts().at(1).at(\"image_url\"),\n            \"https://example.com/image.jpg\");\n  ASSERT_EQ(conv.seps.at(0), \"\\n\");\n  ASSERT_EQ(conv.role_content_sep, \": \");\n  ASSERT_EQ(conv.role_empty_sep, \":\");\n  ASSERT_EQ(conv.stop_str.at(0), \"<|endoftext|>\");\n  ASSERT_EQ(conv.add_role_after_system_message, false);\n  ASSERT_EQ(conv.stop_token_ids.at(0), 50256);\n}\n\nTEST(JsonFFIConvTest, LoadJSONTextContentTest) { _TestConvTemplateLoadJSONTextContent(); }\nTEST(JsonFFIConvTest, LoadJSONPartsContentTest) { _TestConvTemplateLoadJSONPartsContent(); }\n\n}  // namespace json_ffi\n}  // namespace llm\n}  // namespace mlc\n"
  },
  {
    "path": "tests/python/__init__.py",
    "content": ""
  },
  {
    "path": "tests/python/compiler_pass/test_fuse_ft_dequantize_matmul_epilogue.py",
    "content": "# pylint: disable=invalid-name,missing-docstring,too-few-public-methods\nimport tvm\nfrom tvm.ir import assert_structural_equal\nfrom tvm.script import ir as I\nfrom tvm.script import relax as R\n\nfrom mlc_llm.compiler_pass.fuse_ft_dequantize_matmul_epilogue import (\n    FuseFTDequantizeEpilogue,\n)\n\n\ndef test_fuse_bias():\n    @I.ir_module\n    class Before:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n        ):\n            with R.dataflow():\n                lv1 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        \"identity\",\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                lv2 = R.add(lv1, bias)\n                R.output(lv2)\n            return lv2\n\n    @I.ir_module\n    class After:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n        ) -> R.Tensor((1, 1, 1024), \"float16\"):\n            with R.dataflow():\n                lv2 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int_bias\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        bias,\n                        R.str(\"identity\"),\n 
                       R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                        R.prim_value(0),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                R.output(lv2)\n            return lv2\n\n    seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()])\n    mod = seq(Before)\n    assert_structural_equal(mod, After)\n\n\ndef test_fuse_activation():\n    @I.ir_module\n    class Before:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n        ):\n            with R.dataflow():\n                lv1 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        \"identity\",\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                lv2 = R.nn.silu(lv1)\n                R.output(lv2)\n            return lv2\n\n    @I.ir_module\n    class After:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n        ) -> R.Tensor((1, 1, 1024), \"float16\"):\n            with R.dataflow():\n                lv2 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int\",\n                    (\n                        x,\n                        weight,\n                        
scale,\n                        R.str(\"silu\"),\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                R.output(lv2)\n            return lv2\n\n    seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()])\n    mod = seq(Before)\n    assert_structural_equal(mod, After)\n\n\ndef test_fuse_bias_activation():\n    @I.ir_module\n    class Before:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n        ):\n            with R.dataflow():\n                lv1 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        \"identity\",\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                lv2 = R.add(lv1, bias)\n                lv3 = R.nn.relu(lv2)\n                R.output(lv3)\n            return lv3\n\n    @I.ir_module\n    class After:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n        ) -> R.Tensor((1, 1, 1024), \"float16\"):\n            with R.dataflow():\n                lv2 = R.call_dps_packed(\n   
                 \"fastertransformer.gemm_fp16_int_bias\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        bias,\n                        R.str(\"relu\"),\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                        R.prim_value(0),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                R.output(lv2)\n            return lv2\n\n    seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()])\n    mod = seq(Before)\n    assert_structural_equal(mod, After)\n\n\ndef test_fuse_residual_binary():\n    @I.ir_module\n    class Before:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n            residual: R.Tensor((1, 1, 1024), \"float16\"),\n        ):\n            with R.dataflow():\n                lv1 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        \"identity\",\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                lv2 = R.add(lv1, bias)\n                lv3 = R.nn.relu(lv2)\n                lv4 = R.multiply(lv3, residual)\n                R.output(lv4)\n            return lv4\n\n    @I.ir_module\n    class After:\n        @R.function\n        def 
main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n            residual: R.Tensor((1, 1, 1024), \"float16\"),\n        ) -> R.Tensor((1, 1, 1024), \"float16\"):\n            with R.dataflow():\n                lv2 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int_bias_residual\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        bias,\n                        residual,\n                        R.str(\"relu\"),\n                        R.str(\"multiply\"),\n                        R.str(\"identity\"),\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                R.output(lv2)\n            return lv2\n\n    seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()])\n    mod = seq(Before)\n    assert_structural_equal(mod, After)\n\n\ndef test_fuse_residual_unary():\n    @I.ir_module\n    class Before:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n            residual: R.Tensor((1, 1, 1024), \"float16\"),\n        ):\n            with R.dataflow():\n                lv1 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        \"identity\",\n                        
R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                lv2 = R.add(lv1, bias)\n                lv3 = R.nn.relu(lv2)\n                lv4 = R.add(lv3, residual)\n                lv5 = R.nn.gelu(lv4)\n                R.output(lv5)\n            return lv5\n\n    @I.ir_module\n    class After:\n        @R.function\n        def main(\n            x: R.Tensor((1, 1, 4096), \"float16\"),\n            weight: R.Tensor((4096, 512), \"int8\"),\n            scale: R.Tensor((1, 1024), \"float16\"),\n            bias: R.Tensor((1, 1, 1024), \"float16\"),\n            residual: R.Tensor((1, 1, 1024), \"float16\"),\n        ) -> R.Tensor((1, 1, 1024), \"float16\"):\n            with R.dataflow():\n                lv2 = R.call_dps_packed(\n                    \"fastertransformer.gemm_fp16_int_bias_residual\",\n                    (\n                        x,\n                        weight,\n                        scale,\n                        bias,\n                        residual,\n                        R.str(\"relu\"),\n                        R.str(\"plus\"),\n                        R.str(\"gelu\"),\n                        R.prim_value(1),\n                        R.prim_value(1024),\n                        R.prim_value(4096),\n                        R.prim_value(4096),\n                    ),\n                    out_sinfo=R.Tensor((1, 1, 1024), \"float16\"),\n                )\n                R.output(lv2)\n            return lv2\n\n    seq = tvm.transform.Sequential([FuseFTDequantizeEpilogue()])\n    mod = seq(Before)\n    assert_structural_equal(mod, After)\n\n\nif __name__ == \"__main__\":\n    test_fuse_bias()\n    test_fuse_activation()\n    test_fuse_bias_activation()\n    test_fuse_residual_binary()\n    test_fuse_residual_unary()\n"
  },
  {
    "path": "tests/python/conftest.py",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n# pylint: disable=missing-module-docstring,unused-import\nimport pytest\n\n\ndef pytest_configure(config):\n    \"\"\"Register markers\"\"\"\n    config.addinivalue_line(\n        \"markers\",\n        \"unittest: unittests for modules, do not require GPU, usually run fast\",\n    )\n    config.addinivalue_line(\"markers\", \"op_correctness: unittest for op correctness, requires GPU\")\n    config.addinivalue_line(\n        \"markers\",\n        (\n            \"engine: testing engine feature functionalities, requires model and GPU, \"\n            \"note: for most request-related tests, use endpoint test instead.\"\n        ),\n    )\n    config.addinivalue_line(\n        \"markers\",\n        (\n            \"endpoint: sending requests to a global endpoint fixture (can be a REST or API), \"\n            \"tests compatibility of API behaviors\"\n        ),\n    )\n    config.addinivalue_line(\n        \"markers\",\n        \"uncategorized: this test is not yet categorized, team should work to categorize it\",\n    )\n"
  },
  {
    "path": "tests/python/conversation_template/test_conversation_protocol.py",
    "content": "import pytest\n\nfrom mlc_llm.conversation_template import ConvTemplateRegistry\nfrom mlc_llm.protocol.conversation_protocol import Conversation, MessagePlaceholders\n\n\ndef get_conv_templates():\n    return [\n        \"llama-3\",\n        \"llama-2\",\n        \"mistral_default\",\n        \"gorilla\",\n        \"gorilla-openfunctions-v2\",\n        \"chatml\",\n        \"phi-2\",\n        \"codellama_completion\",\n        \"codellama_instruct\",\n        \"rwkv_world\",\n    ]\n\n\n@pytest.mark.parametrize(\"conv_template_name\", get_conv_templates())\ndef test_json(conv_template_name):\n    template = ConvTemplateRegistry.get_conv_template(conv_template_name)\n    j = template.to_json_dict()\n    template_parsed = Conversation.from_json_dict(j)\n    assert template == template_parsed\n\n\n@pytest.mark.parametrize(\"conv_template_name\", get_conv_templates())\ndef test_prompt(conv_template_name):\n    conversation = ConvTemplateRegistry.get_conv_template(conv_template_name)\n    user_msg = \"test1\"\n    assistant_msg = \"test2\"\n    prompt = \"test3\"\n\n    expected_user_msg = (\n        conversation.role_templates[\"user\"]\n        .replace(MessagePlaceholders.USER.value, user_msg)\n        .replace(MessagePlaceholders.FUNCTION.value, \"\")\n    )\n\n    expected_prompt = (\n        conversation.role_templates[\"user\"]\n        .replace(MessagePlaceholders.USER.value, prompt)\n        .replace(MessagePlaceholders.FUNCTION.value, \"\")\n    )\n\n    conversation.messages.append((\"user\", user_msg))\n    conversation.messages.append((\"assistant\", assistant_msg))\n    conversation.messages.append((\"user\", prompt))\n    conversation.messages.append((\"assistant\", None))\n    res = conversation.as_prompt()\n\n    system_msg = conversation.system_template.replace(\n        MessagePlaceholders.SYSTEM.value, conversation.system_message\n    )\n    expected_final_prompt = (\n        system_msg\n        + (conversation.seps[0] if system_msg 
!= \"\" else \"\")\n        + (\n            conversation.roles[\"user\"] + conversation.role_content_sep\n            if conversation.add_role_after_system_message\n            else \"\"\n        )\n        + expected_user_msg\n        + conversation.seps[0 % len(conversation.seps)]\n        + conversation.roles[\"assistant\"]\n        + conversation.role_content_sep\n        + assistant_msg\n        + conversation.seps[1 % len(conversation.seps)]\n        + conversation.roles[\"user\"]\n        + conversation.role_content_sep\n        + expected_prompt\n        + conversation.seps[0 % len(conversation.seps)]\n        + conversation.roles[\"assistant\"]\n        + conversation.role_empty_sep\n    )\n    assert res == expected_final_prompt\n\n\nif __name__ == \"__main__\":\n    test_json(\"llama-3\")\n"
  },
  {
    "path": "tests/python/conversation_template/test_llama_template.py",
    "content": "import pytest\n\nfrom mlc_llm.conversation_template import ConvTemplateRegistry\n\npytestmark = [pytest.mark.runtime_unittest]\n\n\n# From the official Llama-3 example:\n# https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/\ndef test_llama3_prompt():\n    conversation = ConvTemplateRegistry.get_conv_template(\"llama-3\")\n    system_msg = \"You are a helpful AI assistant for travel tips and recommendations\"\n    user_msg1 = \"What is France's capital?\"\n    assistant_msg1 = \"Bonjour! The capital of France is Paris!\"\n    user_msg2 = \"What can I do there?\"\n    assistant_msg2 = \"Paris, the City of Light, offers a romantic getaway with must-see attractions like the Eiffel Tower and Louvre Museum, romantic experiences like river cruises and charming neighborhoods, and delicious food and drink options, with helpful tips for making the most of your trip.\"\n    prompt = \"Give me a detailed list of the attractions I should visit, and time it takes in each one, to plan my trip accordingly.\"\n\n    conversation.system_message = system_msg\n    conversation.messages.append((\"user\", user_msg1))\n    conversation.messages.append((\"assistant\", assistant_msg1))\n    conversation.messages.append((\"user\", user_msg2))\n    conversation.messages.append((\"assistant\", assistant_msg2))\n    conversation.messages.append((\"user\", prompt))\n    conversation.messages.append((\"assistant\", None))\n    res = conversation.as_prompt()\n\n    expected = (\n        \"<|start_header_id|>system<|end_header_id|>\\n\\n\"\n        \"You are a helpful AI assistant for travel tips and recommendations<|eot_id|>\\n\"\n        \"<|start_header_id|>user<|end_header_id|>\\n\\n\"\n        \"What is France's capital?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        \"Bonjour! 
The capital of France is Paris!<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n\"\n        \"What can I do there?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n        \"Paris, the City of Light, offers a romantic getaway with must-see attractions like the Eiffel Tower and Louvre Museum, romantic experiences like river cruises and charming neighborhoods, and delicious food and drink options, with helpful tips for making the most of your trip.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\n\"\n        \"Give me a detailed list of the attractions I should visit, and time it takes in each one, to plan my trip accordingly.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n\"\n    )\n\n    assert res[0] == expected\n\n\nif __name__ == \"__main__\":\n    test_llama3_prompt()\n"
  },
  {
    "path": "tests/python/integration/test_model_compile.py",
    "content": "# pylint: disable=missing-docstring\nimport concurrent.futures as cf\nimport os\nimport shlex\nimport subprocess\nimport sys\nimport tempfile\nfrom itertools import product\n\nimport tvm\n\nfrom mlc_llm.model import MODEL_PRESETS\nfrom mlc_llm.model import MODELS as SUPPORTED_MODELS\nfrom mlc_llm.quantization import QUANTIZATION as SUPPORTED_QUANTS\nfrom mlc_llm.support.constants import MLC_TEMP_DIR\n\nOPT_LEVEL = \"O2\"\nDEVICE2TARGET = {\n    \"cuda\": {\n        \"kind\": \"cuda\",\n        \"arch\": \"sm_86\",\n        \"max_threads_per_block\": 1024,\n        \"max_num_threads\": 1024,\n        \"max_shared_memory_per_block\": 49152,\n        \"thread_warp_size\": 32,\n    },\n    \"rocm\": {\n        \"kind\": \"rocm\",\n        \"mtriple\": \"amdgcn-amd-amdhsa-hcc\",\n        \"mcpu\": \"gfx1100\",\n        \"thread_warp_size\": 32,\n        \"max_threads_per_block\": 1024,\n        \"max_num_threads\": 256,\n        \"max_shared_memory_per_block\": 65536,\n    },\n    \"vulkan\": {\n        \"kind\": \"vulkan\",\n        \"max_threads_per_block\": 1024,\n        \"max_num_threads\": 256,\n        \"max_shared_memory_per_block\": 32768,\n        \"thread_warp_size\": 1,\n        \"supports_float32\": 1,\n        \"supports_float16\": 1,\n        \"supports_int64\": 1,\n        \"supports_int32\": 1,\n        \"supports_int16\": 1,\n        \"supports_int8\": 1,\n        \"supports_16bit_buffer\": 1,\n    },\n    \"metal\": \"metal\",\n    \"wasm\": \"webgpu\",\n    \"android\": \"android\",\n    \"ios\": \"iphone\",\n}\nDEVICE2SUFFIX = {\n    \"cuda\": \"so\",\n    \"rocm\": \"so\",\n    \"vulkan\": \"so\",\n    \"metal\": \"dylib\",\n    \"wasm\": \"wasm\",\n    \"android\": \"tar\",\n    \"ios\": \"tar\",\n}\nMODELS = list(MODEL_PRESETS.keys())\nQUANTS = [  # TODO(@junrushao): use `list(mlc_llm.quantization.QUANTIZATION.keys())`\n    \"q0f16\",\n    \"q0f32\",\n    \"q3f16_1\",\n    \"q4f16_1\",\n    \"q4f32_1\",\n    
\"q4f16_ft\",\n]\nTENSOR_PARALLEL_SHARDS = [\n    1,\n]\n\n\ndef run_command(log_file, cmd):\n    with open(log_file, \"w\", encoding=\"utf-8\") as file:\n        subprocess.check_call(\n            cmd,\n            stdout=file,\n            stderr=subprocess.STDOUT,\n        )\n\n\ndef test_model_compile():  # pylint: disable=too-many-locals\n    device = sys.argv[1]\n    num_workers = int(sys.argv[2])\n    target = DEVICE2TARGET[device]\n    if not isinstance(target, str):\n        target = str(tvm.target.Target(target))\n    suffix = DEVICE2SUFFIX[device]\n\n    passed_cmds = []\n    failed_cmds = []\n    with tempfile.TemporaryDirectory(dir=MLC_TEMP_DIR) as tmp_dir:\n        with cf.ProcessPoolExecutor(max_workers=num_workers) as executor:\n            log_files = []\n            cmds = []\n            futures = []\n            for idx, (model, quant, tp_shard) in enumerate(\n                product(\n                    MODELS,\n                    QUANTS,\n                    TENSOR_PARALLEL_SHARDS,\n                )\n            ):\n                if (\n                    SUPPORTED_QUANTS[quant].kind\n                    not in SUPPORTED_MODELS[MODEL_PRESETS[model][\"model_type\"]].quantize\n                ):\n                    continue\n                if not target.startswith(\"cuda\") and quant == \"q4f16_ft\":\n                    # FasterTransformer only works with cuda\n                    continue\n                if \"deepseek_v2\" in model and \"32\" in quant:\n                    # Skip f32 for deepseek v2 model for now.\n                    continue\n                log_file = os.path.join(tmp_dir, f\"lib{idx}.log\")\n                cmd = [\n                    sys.executable,\n                    \"-m\",\n                    \"mlc_llm\",\n                    \"compile\",\n                    model,\n                    \"--quantization\",\n                    quant,\n                    \"--overrides\",\n                    
f\"tensor_parallel_shards={tp_shard}\",\n                    \"--device\",\n                    target,\n                    \"--opt\",\n                    OPT_LEVEL,\n                    \"-o\",\n                    os.path.join(tmp_dir, f\"lib{idx}.{suffix}\"),\n                ]\n                future = executor.submit(run_command, log_file, cmd)\n                log_files.append(log_file)\n                cmds.append(cmd)\n                futures.append(future)\n            for log_file, cmd, future in zip(log_files, cmds, futures):\n                cmd = shlex.join(cmd)\n                try:\n                    future.result()\n                    passed_cmds.append(cmd)\n                    print(f\"[PASS] {cmd}\")\n                except Exception:  # pylint: disable=broad-except\n                    failed_cmds.append(cmd)\n                    print(\"-------------------------------\")\n                    print(f\"[FAIL] {cmd}\")\n                    with open(log_file, \"r\", encoding=\"utf-8\") as file:\n                        print(file.read())\n                    print(\"-------------------------------\")\n    print(\"-------------------------------\")\n    print(f\"Total {len(passed_cmds)} passed, {len(failed_cmds)} failed.\")\n    print(\"-------------------------------\")\n    print(\"Passed commands:\")\n    for cmd in passed_cmds:\n        print(cmd)\n    if failed_cmds:\n        print(\"-------------------------------\")\n        print(\"Failed commands:\")\n        for cmd in failed_cmds:\n            print(cmd)\n        sys.exit(1)\n\n\nif __name__ == \"__main__\":\n    test_model_compile()\n"
  },
  {
    "path": "tests/python/json_ffi/test_json_ffi_engine.py",
    "content": "import json\nfrom typing import Dict, List, Optional\n\nimport pytest\nfrom pydantic import BaseModel\n\nfrom mlc_llm.json_ffi import JSONFFIEngine\nfrom mlc_llm.testing import require_test_model\n\n# test category \"engine_feature\"\npytestmark = [pytest.mark.engine_feature]\n\n\nchat_completion_prompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous of? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where is milk tea originated from? Please elaborate in detail.\",\n    \"Where is the southernmost place in United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? Please elaborate in detail.\",\n]\n\nfunction_calling_prompts = [\n    \"What is the temperature in Pittsburgh, PA?\",\n    \"What is the temperature in Tokyo, JP?\",\n    \"What is the temperature in Pittsburgh, PA and Tokyo, JP?\",\n]\n\ntools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_current_weather\",\n            \"description\": \"Get the current weather in a given location\",\n            \"parameters\": {\n                \"type\": \"object\",\n                \"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n                    },\n                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                },\n                \"required\": [\"location\"],\n            },\n        },\n    }\n]\n\n\ndef run_chat_completion(\n    engine: JSONFFIEngine,\n    model: str,\n    prompts: List[str] = chat_completion_prompts,\n    tools: Optional[List[Dict]] = None,\n):\n    num_requests = 2\n    max_tokens = 64\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    for rid in range(num_requests):\n        print(f\"chat completion for request {rid}\")\n        for response in engine.chat.completions.create(\n            messages=[{\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": prompts[rid]}]}],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n            tools=tools,\n        ):\n            for choice in response.choices:\n                assert choice.delta.role == \"assistant\"\n                assert isinstance(choice.delta.content, str)\n                output_texts[rid][choice.index] += choice.delta.content\n\n    # Print output.\n    print(\"Chat completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n\ndef run_json_schema_function_calling(\n    engine: JSONFFIEngine,\n    model: str,\n    prompts: List[str] = function_calling_prompts,\n    tools: Optional[List[Dict]] = None,\n):\n    num_requests = 2\n    max_tokens = 64\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    class ToolCall(BaseModel):\n        name: str\n        
arguments: Dict[str, str]\n\n    class Schema(BaseModel):\n        tool_calls: List[ToolCall]\n\n    schema_str = json.dumps(Schema.model_json_schema())\n    print(\"Schema str\", schema_str)\n\n    for rid in range(num_requests):\n        print(f\"chat completion for request {rid}\")\n        for response in engine.chat.completions.create(\n            messages=[\n                {\n                    \"role\": \"system\",\n                    \"content\": \"You are a function calling AI model. You are provided with function signatures within \"\n                    \"<tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make \"\n                    f\"assumptions about what values to plug into functions. Here are the available tools: <tools> {json.dumps(tools)} </tools> \"\n                    \"Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10. \"\n                    \"Calling multiple functions at once can overload the system and increase cost so call one function at a time please. \"\n                    \"If you plan to continue with analysis, always call another function. 
Return a valid json object (using double \"\n                    f\"quotes) in the following schema: {schema_str}\",\n                },\n                {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": prompts[rid]}]},\n            ],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n            response_format={\"type\": \"json_object\", \"schema\": schema_str},\n        ):\n            for choice in response.choices:\n                assert choice.delta.role == \"assistant\"\n                assert isinstance(choice.delta.content, str)\n                output_texts[rid][choice.index] += choice.delta.content\n\n    # Print output.\n    print(\"Chat completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_chat_completion(model):\n    # Create engine.\n    engine = JSONFFIEngine(model)\n\n    run_chat_completion(engine, model)\n\n    # Test malformed requests.\n    for response in engine._raw_chat_completion(\n        \"malformed_string\", include_usage=False, request_id=\"123\"\n    ):\n        assert len(response.choices) == 1\n        assert response.choices[0].finish_reason == \"error\"\n\n    engine.terminate()\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_reload_reset_unload(model):\n    # Create engine.\n    engine = JSONFFIEngine(model)\n\n    # Run chat completion before and after reload/reset.\n    run_chat_completion(engine, model)\n    engine._test_reload()\n    run_chat_completion(engine, model)\n    engine._test_reset()\n    run_chat_completion(engine, model)\n    
engine._test_unload()\n\n    engine.terminate()\n\n\n@require_test_model(\"Hermes-2-Pro-Mistral-7B-q4f16_1-MLC\")\ndef test_json_schema_with_system_prompt(model):\n    engine = JSONFFIEngine(model)\n\n    # run function calling\n    run_json_schema_function_calling(engine, model, function_calling_prompts, tools)\n\n    engine.terminate()\n\n\nif __name__ == \"__main__\":\n    test_chat_completion()\n    test_reload_reset_unload()\n    test_json_schema_with_system_prompt()\n"
  },
  {
    "path": "tests/python/json_ffi/test_json_ffi_engine_image.py",
    "content": "import base64\nfrom typing import Dict, List, Optional\n\nimport requests\n\nfrom mlc_llm.json_ffi import JSONFFIEngine\nfrom mlc_llm.testing import require_test_model\n\n\ndef base64_encode_image(url: str) -> str:\n    response = requests.get(url)\n    response.raise_for_status()  # Ensure we got a successful response\n    image_data = base64.b64encode(response.content)\n    image_data_str = image_data.decode(\"utf-8\")\n    data_url = f\"data:image/jpeg;base64,{image_data_str}\"\n    return data_url\n\n\nimage_prompts = [\n    [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": f\"{base64_encode_image('https://llava-vl.github.io/static/images/view.jpg')}\",\n                },\n                {\"type\": \"text\", \"text\": \"What does the image represent?\"},\n            ],\n        }\n    ]\n]\n\n\ndef run_chat_completion(\n    engine: JSONFFIEngine,\n    model: str,\n    prompts: List[List[Dict]] = image_prompts,\n    tools: Optional[List[Dict]] = None,\n):\n    num_requests = 1\n    max_tokens = 64\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    for rid in range(num_requests):\n        print(f\"chat completion for request {rid}\")\n        for response in engine.chat.completions.create(\n            messages=prompts[rid],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n            tools=tools,\n        ):\n            for choice in response.choices:\n                assert choice.delta.role == \"assistant\"\n                assert isinstance(choice.delta.content[0], Dict)\n                assert choice.delta.content[0][\"type\"] == \"text\"\n                output_texts[rid][choice.index] += choice.delta.content[0][\"text\"]\n\n    # Print output.\n    print(\"Chat completion all finished\")\n  
  for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n\n@require_test_model(\"llava-1.5-7b-hf-q4f16_1-MLC\")\ndef test_chat_completion(model: str):\n    # Create engine.\n    engine = JSONFFIEngine(\n        model,\n        max_total_sequence_length=1024,\n    )\n\n    run_chat_completion(engine, model)\n\n    # Test malformed requests.\n    for response in engine._raw_chat_completion(\n        \"malformed_string\", include_usage=False, request_id=\"123\"\n    ):\n        assert len(response.choices) == 1\n        assert response.choices[0].finish_reason == \"error\"\n\n    engine.terminate()\n\n\nif __name__ == \"__main__\":\n    test_chat_completion()\n"
  },
  {
    "path": "tests/python/json_ffi/test_json_ffi_engine_mock.py",
    "content": "import json\n\nimport pytest\nimport tvm\n\nfrom mlc_llm.json_ffi import JSONFFIEngine\nfrom mlc_llm.testing import require_test_model\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\ndef check_error_handling(engine, expect_str, **params):\n    \"\"\"Check error handling in raw completion API\"\"\"\n    body = {\n        \"messages\": [{\"role\": \"user\", \"content\": \"hello\"}],\n        \"stream_options\": {\"include_usage\": True},\n    }\n    body.update(params)\n\n    for response in engine._raw_chat_completion(\n        json.dumps(body), include_usage=False, request_id=\"123\"\n    ):\n        if response.choices[0].finish_reason is not None:\n            break\n    if response.choices[0].finish_reason != \"error\":\n        raise RuntimeError(f\"expected request {params} to hit an error\")\n\n    if expect_str not in response.choices[0].delta.content:\n        raise RuntimeError(\n            f\"expected '{expect_str}' in error msg, \"\n            f\"but got '{response.choices[0].delta.content}'\"\n        )\n\n\n# NOTE: only the tokenizer files in the model folder are needed.\n# The mock test launches quickly, so it can run as a unittest.\n@require_test_model(\"Llama-3-8B-Instruct-q4f16_1-MLC\")\ndef test_chat_completion_misuse(model: str):\n    engine = JSONFFIEngine(model, tvm.cpu(), model_lib=\"mock://echo\")\n    # Test malformed requests.\n    for response in engine._raw_chat_completion(\n        \"malformed_string\", include_usage=False, request_id=\"123\"\n    ):\n        assert len(response.choices) == 1\n        assert response.choices[0].finish_reason == \"error\"\n    # Check parameter validation.\n    check_error_handling(engine, \"should be non-negative\", temperature=-1)\n    check_error_handling(engine, \"in range [0, 1]\", top_p=100)\n    check_error_handling(engine, \"frequency_penalty\", frequency_penalty=100)\n\n\ndef check_normal_param_passing(engine):\n    json_schema = \"\"\"\n    {\"properties\": {\"result\": {\"items\": {\"type\": 
\"integer\"}, \"title\": \"Result\", \"type\": \"array\"}},\n      \"required\": [\"result\"], \"title\": \"Output\", \"type\": \"object\"}\n    \"\"\"\n    param_dict = {\n        \"top_p\": 0.6,\n        \"temperature\": 0.8,\n        \"frequency_penalty\": 0.1,\n        \"presence_penalty\": 0.1,\n    }\n    usage = None\n    for response in engine.chat.completions.create(\n        messages=[{\"role\": \"user\", \"content\": \"hello\"}],\n        stream=True,\n        stream_options={\"include_usage\": True},\n        response_format={\"type\": \"json_object\", \"schema\": json_schema},\n        **param_dict,  # type: ignore\n    ):\n        if response.usage is not None:\n            usage = response.usage\n\n    # The echo mock echoes the generation config back in usage.extra.\n    assert usage is not None, \"no usage chunk was received\"\n    for k, v in param_dict.items():\n        assert usage.extra[k] == v, f\"{k} mismatch\"\n    assert \"response_format\" in usage.extra\n    assert usage.extra[\"response_format\"][\"type\"] == \"json_object\"\n    assert \"schema\" in usage.extra[\"response_format\"]\n\n\ndef check_n_generation(engine):\n    hit_set = set()\n    for response in engine.chat.completions.create(\n        messages=[{\"role\": \"user\", \"content\": \"hello\"}],\n        stream=True,\n        stream_options={\"include_usage\": True},\n        n=3,\n    ):\n        for choice in response.choices:\n            hit_set.add(choice.index)\n    for i in range(3):\n        assert i in hit_set, f\"choice index {i} missing from n=3 generation\"\n\n\n@require_test_model(\"Llama-3-8B-Instruct-q4f16_1-MLC\")\ndef test_chat_completion_api(model: str):\n    engine = JSONFFIEngine(model, tvm.cpu(), model_lib=\"mock://echo\")\n    check_normal_param_passing(engine)\n    check_n_generation(engine)\n\n\nif __name__ == \"__main__\":\n    test_chat_completion_api()\n    test_chat_completion_misuse()\n"
  },
  {
    "path": "tests/python/loader/test_awq.py",
    "content": "# pylint: disable=missing-docstring\nfrom pathlib import Path\nfrom typing import Union\n\nimport pytest\nimport tvm\n\nfrom mlc_llm.loader import HuggingFaceLoader\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\nfrom mlc_llm.quantization import QUANTIZATION\nfrom mlc_llm.support import logging, tqdm\n\nlogging.enable_logging()\n\n\n@pytest.mark.parametrize(\n    \"param_path\",\n    [\n        \"./dist/models/llama-2-7b-w4-g128-awq.pt\",\n        \"./dist/models/Llama-2-7B-AWQ/model.safetensors\",\n    ],\n)\ndef test_load_llama(param_path: Union[str, Path]):\n    path_params = Path(param_path)\n\n    model = MODELS[\"llama\"]\n    quantization = QUANTIZATION[\"q4f16_awq\"]\n    config = model.config.from_dict(MODEL_PRESETS[\"llama2_7b\"])\n    loader = HuggingFaceLoader(\n        path=path_params,\n        extern_param_map=model.source[\"awq\"](config, quantization),\n    )\n    with tqdm.redirect():\n        for _name, _param in loader.load(tvm.device(\"cpu\")):\n            ...\n\n\nif __name__ == \"__main__\":\n    test_load_llama(param_path=\"./dist/models/llama-2-7b-w4-g128-awq.pt\")\n    test_load_llama(param_path=\"./dist/models/Llama-2-7B-AWQ/model.safetensors\")\n"
  },
  {
    "path": "tests/python/loader/test_huggingface.py",
    "content": "# pylint: disable=missing-docstring\nfrom pathlib import Path\nfrom typing import Union\n\nimport pytest\nimport tvm\n\nfrom mlc_llm.loader import HuggingFaceLoader\nfrom mlc_llm.model import MODELS\nfrom mlc_llm.support import logging, tqdm\n\nlogging.enable_logging()\n\n\n@pytest.mark.parametrize(\n    \"base_path\",\n    [\n        \"./dist/models/Llama-2-7b-hf\",\n        \"./dist/models/Llama-2-13b-hf\",\n        \"./dist/models/Llama-2-70b-hf\",\n    ],\n)\ndef test_load_torch_llama(base_path: Union[str, Path]):\n    base_path = Path(base_path)\n    path_config = base_path / \"config.json\"\n    path_params = base_path / \"pytorch_model.bin.index.json\"\n\n    model = MODELS[\"llama\"]\n    config = model.config.from_file(path_config)\n    loader = HuggingFaceLoader(\n        path=path_params,\n        extern_param_map=model.source[\"huggingface-torch\"](config, None),\n    )\n    with tqdm.redirect():\n        for _name, _param in loader.load(device=tvm.device(\"cpu\")):\n            return  # To reduce the time of the test\n\n\n@pytest.mark.parametrize(\n    \"base_path\",\n    [\n        \"./dist/models/Llama-2-7b-hf\",\n        \"./dist/models/Llama-2-13b-hf\",\n        \"./dist/models/Llama-2-70b-hf\",\n    ],\n)\ndef test_load_safetensor_llama(base_path: Union[str, Path]):\n    base_path = Path(base_path)\n    path_config = base_path / \"config.json\"\n    path_params = base_path / \"model.safetensors.index.json\"\n\n    model = MODELS[\"llama\"]\n    config = model.config.from_file(path_config)\n    loader = HuggingFaceLoader(\n        path=path_params,\n        extern_param_map=model.source[\"huggingface-safetensor\"](config, None),\n    )\n    with tqdm.redirect():\n        for _name, _param in loader.load(device=tvm.device(\"cpu\")):\n            return  # To reduce the time of the test\n\n\nif __name__ == \"__main__\":\n    test_load_torch_llama(base_path=\"./dist/models/Llama-2-7b-hf\")\n    
test_load_torch_llama(base_path=\"./dist/models/Llama-2-13b-hf\")\n    test_load_torch_llama(base_path=\"./dist/models/Llama-2-70b-hf\")\n    test_load_safetensor_llama(base_path=\"./dist/models/Llama-2-7b-hf\")\n    test_load_safetensor_llama(base_path=\"./dist/models/Llama-2-13b-hf\")\n    test_load_safetensor_llama(base_path=\"./dist/models/Llama-2-70b-hf\")\n"
  },
  {
    "path": "tests/python/model/test_gemma3.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\n\"\"\"Unit tests for Gemma3 model architecture.\"\"\"\n\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\n\n\ndef test_gemma3_model_registered():\n    \"\"\"Verify Gemma3 model is in the registry.\"\"\"\n    assert \"gemma3\" in MODELS, \"gemma3 should be registered in MODELS\"\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\n        \"gemma3_2b\",\n        \"gemma3_9b\",\n    ],\n)\ndef test_gemma3_creation(model_name: str):\n    \"\"\"Test Gemma3 model creation and export to TVM IR.\n\n    Verifies:\n    - Config can be loaded from preset\n    - Model instance can be created\n    - Model exports to TVM IR successfully\n    - Named parameters are extracted\n    \"\"\"\n    model_info = MODELS[\"gemma3\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model = model_info.model(config)\n    mod, named_params = model.export_tvm(\n        spec=model.get_default_spec(),  # type: ignore\n    )\n\n    # Verify export succeeded\n    assert mod is not None\n    assert len(named_params) > 0\n\n    # Optional: show module structure\n    mod.show(black_format=False)\n\n    # Print parameters for debugging\n    for name, param in named_params:\n        print(name, param.shape, param.dtype)\n\n\ndef test_gemma3_config_validation():\n    \"\"\"Test Gemma3 configuration has required fields.\"\"\"\n    model_info = MODELS[\"gemma3\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[\"gemma3_2b\"])\n\n    # Check required config parameters\n    assert hasattr(config, \"hidden_size\") and config.hidden_size > 0\n    assert hasattr(config, \"num_hidden_layers\") and config.num_hidden_layers > 0\n    assert hasattr(config, \"num_attention_heads\") and config.num_attention_heads > 0\n    assert hasattr(config, \"vocab_size\") and config.vocab_size > 0\n\n    print(\n        f\"Gemma3 Config: hidden_size={config.hidden_size}, \"\n        
f\"layers={config.num_hidden_layers}, \"\n        f\"heads={config.num_attention_heads}, \"\n        f\"vocab={config.vocab_size}\"\n    )\n\n\nif __name__ == \"__main__\":\n    # Allow running the tests directly.\n    test_gemma3_model_registered()\n    test_gemma3_config_validation()\n    test_gemma3_creation(\"gemma3_2b\")\n    test_gemma3_creation(\"gemma3_9b\")\n"
  },
  {
    "path": "tests/python/model/test_gpt2.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\n\n\n@pytest.mark.parametrize(\"model_name\", [\"gpt2\"])\ndef test_gpt2_creation(model_name: str):\n    model_info = MODELS[\"gpt2\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model = model_info.model(config)\n    mod, named_params = model.export_tvm(\n        spec=model.get_default_spec(),  # type: ignore\n    )\n    mod.show(black_format=False)\n    for name, param in named_params:\n        print(name, param.shape, param.dtype)\n\n\nif __name__ == \"__main__\":\n    test_gpt2_creation(\"gpt2\")\n"
  },
  {
    "path": "tests/python/model/test_gptNeox.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\n\n\n@pytest.mark.parametrize(\"model_name\", [\"redpajama_3b_v1\"])\ndef test_gpt_neox_creation(model_name: str):\n    model_info = MODELS[\"gpt_neox\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model = model_info.model(config)\n    mod, named_params = model.export_tvm(\n        spec=model.get_default_spec(),  # type: ignore\n    )\n    mod.show(black_format=False)\n    for name, param in named_params:\n        print(name, param.shape, param.dtype)\n\n\nif __name__ == \"__main__\":\n    test_gpt_neox_creation(\"redpajama_3b_v1\")\n"
  },
  {
    "path": "tests/python/model/test_kv_cache.py",
    "content": "# pylint: disable=line-too-long,missing-docstring\nimport tvm\nfrom tvm import tir\nfrom tvm.relax.frontend.nn import core, modules, spec\nfrom tvm.script import ir as I\nfrom tvm.script import relax as R\nfrom tvm.script import tir as T\n\nfrom mlc_llm.nn.kv_cache import PagedKVCache, RopeMode\n\n# mypy: disable-error-code=\"attr-defined\"\n# pylint: disable=invalid-name,unused-argument,too-many-locals,too-many-statements\n\n\ndef test_nn_module_paged_kv_cache():\n    # fmt: off\n    @I.ir_module\n    class Module:\n        @R.function\n        def create_paged_kv_cache(\n            max_batch_size: R.Shape([\"max_batch_size_1\"]),  # type: ignore\n            max_total_seq_len: R.Shape([\"max_total_seq_len_1\"]),  # type: ignore\n            prefill_chunk_size: R.Shape([\"prefill_chunk_size_1\"]),  # type: ignore\n            page_size: R.Shape([\"page_size_1\"]),  # type: ignore\n            support_sliding_window: R.Shape([\"support_sliding_window_1\"]),  # type: ignore\n        ) -> R.Object:\n            max_batch_size_1 = T.int64()\n            max_total_seq_len_1 = T.int64()\n            prefill_chunk_size_1 = T.int64()\n            page_size_1 = T.int64()\n            support_sliding_window_1 = T.int64()\n            R.func_attr({\"num_input\": 5})\n            with R.dataflow():\n                paged_kv_cache: R.Object = R.call_pure_packed(\"mlc.create_paged_kv_cache_generic\", R.shape([max_batch_size_1, max_total_seq_len_1, prefill_chunk_size_1, page_size_1, support_sliding_window_1]), R.prim_value(32), R.prim_value(32), R.prim_value(32), R.prim_value(128), R.prim_value(1), R.prim_value(1), R.prim_value(10000), R.prim_value(128), R.dtype(\"float16\"), sinfo_args=(R.Object,))\n                gv1: R.Object = paged_kv_cache\n                R.output(gv1)\n            return gv1\n\n        @R.function\n        def forward(\n            cache: R.Object, qkv: R.Tensor((1, 100, 96, 128), dtype=\"float16\")  # type: ignore\n        ) -> 
R.Tensor((1, 100, 32, 128), dtype=\"float16\"):  # type: ignore\n            R.func_attr({\"num_input\": 2})\n            with R.dataflow():\n                reshape: R.Tensor((100, 96, 128), dtype=\"float16\") = R.reshape(  # type: ignore\n                    qkv, R.shape([100, 96, 128])\n                )\n                lv = R.call_dps_packed(\n                    \"vm.builtin.attention_kv_cache_attention_with_fused_qkv\",\n                    (cache, R.prim_value(0), R.prim_value(T.float32(1)), reshape),\n                    out_sinfo=R.Tensor((100, 32, 128), dtype=\"float16\"),\n                )\n                reshape1: R.Tensor((1, 100, 32, 128), dtype=\"float16\") = R.reshape(  # type: ignore\n                    lv, R.shape([1, 100, 32, 128])\n                )\n                gv: R.Tensor((1, 100, 32, 128), dtype=\"float16\") = reshape1  # type: ignore\n                R.output(gv)\n            return gv\n    # fmt: on\n\n    class PagedKVCacheTest(modules.Module):\n        def forward(\n            self,\n            cache: PagedKVCache,\n            qkv: core.Tensor,\n        ) -> core.Tensor:\n            return cache.attention_with_fused_qkv(0, qkv, num_qo_heads=32, sm_scale=128**-0.5)\n\n        def create_paged_kv_cache(\n            self,\n            max_batch_size: tir.Var,\n            max_total_seq_len: tir.Var,\n            prefill_chunk_size: tir.Var,\n            page_size: tir.Var,\n            support_sliding_window: tir.Var,\n        ) -> PagedKVCache:\n            return PagedKVCache.create_generic(\n                attn_kind=\"mha\",\n                max_batch_size=max_batch_size,\n                max_total_seq_len=max_total_seq_len,\n                prefill_chunk_size=prefill_chunk_size,\n                page_size=page_size,\n                support_sliding_window=support_sliding_window,\n                num_hidden_layers=32,\n                num_attention_heads=32,\n                num_key_value_heads=32,\n                
qk_head_dim=128,\n                v_head_dim=128,\n                rope_mode=RopeMode.NORMAL,\n                rope_scale=1,\n                rope_theta=10000,\n                rotary_dim=128,\n                dtype=\"float16\",\n            )\n\n    export_results = PagedKVCacheTest().export_tvm(\n        spec={\n            \"forward\": {\n                \"cache\": spec.Object(object_type=PagedKVCache),\n                \"qkv\": spec.Tensor((1, 100, 96, 128), \"float16\"),\n            },\n            \"create_paged_kv_cache\": {\n                \"max_batch_size\": int,\n                \"max_total_seq_len\": int,\n                \"prefill_chunk_size\": int,\n                \"page_size\": int,\n                \"support_sliding_window\": int,\n            },\n        },\n    )\n    tvm_mod = export_results[0]\n    tvm.ir.assert_structural_equal(tvm_mod, Module, True)\n\n\nif __name__ == \"__main__\":\n    test_nn_module_paged_kv_cache()\n"
  },
  {
    "path": "tests/python/model/test_llama.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\n\n\n@pytest.mark.parametrize(\n    \"model_name\", [\"llama2_7b\", \"llama2_13b\", \"llama2_70b\", \"tinyllama_1b_chat_v1.0\"]\n)\ndef test_llama2_creation(model_name: str):\n    model_info = MODELS[\"llama\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model = model_info.model(config)\n    mod, named_params = model.export_tvm(\n        spec=model.get_default_spec(),  # type: ignore\n    )\n    mod.show(black_format=False)\n    for name, param in named_params:\n        print(name, param.shape, param.dtype)\n\n\nif __name__ == \"__main__\":\n    test_llama2_creation(\"llama2_7b\")\n    test_llama2_creation(\"llama2_13b\")\n    test_llama2_creation(\"llama2_70b\")\n    test_llama2_creation(\"tinyllama_1b_chat_v1.0\")\n"
  },
  {
    "path": "tests/python/model/test_llama_quantization.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\nfrom mlc_llm.quantization import QUANTIZATION\nfrom mlc_llm.quantization.group_quantization import (\n    GroupQuantizeEmbedding,\n    GroupQuantizeLinear,\n)\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\"llama2_7b\", \"llama2_13b\", \"llama2_70b\"],\n)\n@pytest.mark.parametrize(\n    \"quant_name\",\n    [\"q3f16_1\", \"q4f16_1\", \"q4f32_1\"],\n)\ndef test_llama2_group_quantization(model_name: str, quant_name: str):\n    model_info = MODELS[\"llama\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model, quant_map = model_info.quantize[\"group-quant\"](config, QUANTIZATION[quant_name])\n    assert \"model.embed_tokens.weight\" in quant_map.param_map\n    assert isinstance(\n        model.model.embed_tokens,  # type: ignore[attr-defined]\n        GroupQuantizeEmbedding,\n    )\n    assert \"lm_head.weight\" in quant_map.param_map\n    assert isinstance(model.lm_head, GroupQuantizeLinear)  # type: ignore[attr-defined]\n    for i in range(config.num_hidden_layers):\n        assert f\"model.layers.{i}.self_attn.qkv_proj.weight\" in quant_map.param_map\n        assert isinstance(\n            model.model.layers[i].self_attn.qkv_proj,  # type: ignore[attr-defined]\n            GroupQuantizeLinear,\n        )\n        assert f\"model.layers.{i}.self_attn.o_proj.weight\" in quant_map.param_map\n        assert isinstance(\n            model.model.layers[i].self_attn.o_proj,  # type: ignore[attr-defined]\n            GroupQuantizeLinear,\n        )\n        assert f\"model.layers.{i}.mlp.gate_up_proj.weight\" in quant_map.param_map\n        assert isinstance(\n            model.model.layers[i].mlp.gate_up_proj,  # type: ignore[attr-defined]\n            GroupQuantizeLinear,\n        )\n        assert f\"model.layers.{i}.mlp.down_proj.weight\" in quant_map.param_map\n        assert isinstance(\n          
  model.model.layers[i].mlp.down_proj,  # type: ignore[attr-defined]\n            GroupQuantizeLinear,\n        )\n\n\n@pytest.mark.parametrize(\n    \"model_name\",\n    [\"llama2_7b\", \"llama2_13b\", \"llama2_70b\"],\n)\n@pytest.mark.parametrize(\n    \"quant_name\",\n    [\"q0f16\", \"q0f32\"],\n)\ndef test_llama2_no_quantization(model_name: str, quant_name: str):\n    model_info = MODELS[\"llama\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    _, quant_map = model_info.quantize[\"no-quant\"](config, QUANTIZATION[quant_name])\n    assert len(quant_map.param_map) == 0\n    assert len(quant_map.map_func) == 0\n\n\nif __name__ == \"__main__\":\n    test_llama2_group_quantization(\"llama2_7b\", \"q4f16_1\")\n    test_llama2_group_quantization(\"llama2_13b\", \"q4f16_1\")\n    test_llama2_group_quantization(\"llama2_70b\", \"q4f16_1\")\n"
  },
  {
    "path": "tests/python/model/test_mistral.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\n\n\n@pytest.mark.parametrize(\"model_name\", [\"mistral_7b\"])\ndef test_mistral_creation(model_name: str):\n    model_info = MODELS[\"mistral\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model = model_info.model(config)\n    mod, named_params = model.export_tvm(\n        spec=model.get_default_spec(),  # type: ignore\n    )\n    mod.show(black_format=False)\n    for name, param in named_params:\n        print(name, param.shape, param.dtype)\n\n\nif __name__ == \"__main__\":\n    test_mistral_creation(\"mistral_7b\")\n"
  },
  {
    "path": "tests/python/model/test_phi.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport pytest\n\nfrom mlc_llm.model import MODEL_PRESETS, MODELS\n\n\n@pytest.mark.parametrize(\"model_name\", [\"phi-1_5\", \"phi-2\"])\ndef test_phi_creation(model_name: str):\n    model_info = MODELS[\"phi-msft\"]\n    config = model_info.config.from_dict(MODEL_PRESETS[model_name])\n    model = model_info.model(config)\n    mod, named_params = model.export_tvm(\n        spec=model.get_default_spec(),  # type: ignore\n    )\n    mod.show(black_format=False)\n    for name, param in named_params:\n        print(name, param.shape, param.dtype)\n\n\nif __name__ == \"__main__\":\n    test_phi_creation(\"phi-1_5\")\n    test_phi_creation(\"phi-2\")\n"
  },
  {
    "path": "tests/python/model/test_qwen3_embedding.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nimport json\nimport os\n\nimport numpy as np\nimport pytest\nimport torch\nimport tvm\nfrom safetensors import safe_open\nfrom transformers import AutoModel, AutoTokenizer\nfrom tvm import relax\nfrom tvm.contrib import tvmjs\nfrom tvm.runtime import ShapeTuple\nfrom tvm.runtime.vm import VirtualMachine\n\nMLC_QWEN3_EMB_HF_DIR = os.environ.get(\"MLC_QWEN3_EMB_HF_DIR\")\nMLC_QWEN3_EMB_MODEL_DIR = os.environ.get(\"MLC_QWEN3_EMB_MODEL_DIR\")\nMLC_QWEN3_EMB_MODEL_LIB = os.environ.get(\"MLC_QWEN3_EMB_MODEL_LIB\")\nMLC_QWEN3_EMB_DEVICE = os.environ.get(\"MLC_QWEN3_EMB_DEVICE\", \"cuda\")\n\n_skip = not all([MLC_QWEN3_EMB_HF_DIR, MLC_QWEN3_EMB_MODEL_DIR, MLC_QWEN3_EMB_MODEL_LIB])\n_skip_reason = (\n    \"Set MLC_QWEN3_EMB_HF_DIR, MLC_QWEN3_EMB_MODEL_DIR, \" \"MLC_QWEN3_EMB_MODEL_LIB to run this test\"\n)\n\nTEST_TEXTS = [\n    \"What is machine learning?\",\n    \"CMU is Carnegie Mellon University\",\n    \"机器学习是人工智能的一个分支\",\n    \"量子コンピュータの基本原理を説明してください\",\n    \"머신러닝은 인공지능의 한 분야입니다.\",\n    (\n        \"Instruct: Given a web search query, retrieve relevant passages \"\n        \"that answer the query\\nQuery: What is the capital of China?\"\n    ),\n    (\n        \"The Transformer architecture, introduced in the paper Attention Is All You Need, \"\n        \"revolutionized natural language processing by replacing recurrent layers with \"\n        \"self-attention mechanisms. This allows the model to process all positions in a \"\n        \"sequence simultaneously rather than sequentially, leading to significant improvements \"\n        \"in both training efficiency and the ability to capture long-range dependencies. 
\"\n        \"The key innovation is the multi-head attention mechanism, which allows the model \"\n        \"to jointly attend to information from different representation subspaces at \"\n        \"different positions.\"\n    ),\n    \"Hello\",\n    \"def fibonacci(n): return n if n <= 1 else fibonacci(n-1) + fibonacci(n-2)\",\n]\n\n\ndef _load_embed_weight(hf_dir):\n    safetensor_files = [f for f in os.listdir(hf_dir) if f.endswith(\".safetensors\")]\n    for sf in safetensor_files:\n        with safe_open(os.path.join(hf_dir, sf), framework=\"pt\", device=\"cpu\") as f:\n            if \"embed_tokens.weight\" in f.keys():\n                return f.get_tensor(\"embed_tokens.weight\")\n    raise FileNotFoundError(f\"embed_tokens.weight not found in {hf_dir}\")\n\n\ndef _hf_logits(text, tokenizer, hf_model, embed_weight):\n    inputs = tokenizer(text, return_tensors=\"pt\")\n    with torch.no_grad():\n        hidden = hf_model(**inputs).last_hidden_state.float()\n        logits = hidden @ embed_weight.float().T\n    return logits[0, -1, :].numpy()\n\n\ndef _mlc_logits(text, tokenizer, mlc_module, params, metadata, dev, embed_weight):\n    input_ids = tokenizer(text, return_tensors=\"pt\")[\"input_ids\"][0].numpy().astype(np.int32)\n    seq_len = len(input_ids)\n\n    embed_func = mlc_module[\"embed\"]\n    prefill_func = mlc_module[\"prefill_to_last_hidden_states\"]\n\n    if mlc_module.implements_function(\"create_flashinfer_paged_kv_cache\"):\n        create_kv = mlc_module[\"create_flashinfer_paged_kv_cache\"]\n    elif mlc_module.implements_function(\"create_tir_paged_kv_cache\"):\n        create_kv = mlc_module[\"create_tir_paged_kv_cache\"]\n    else:\n        raise RuntimeError(\"Cannot find KV cache creation function\")\n\n    sliding_window = metadata.get(\"sliding_window_size\", -1)\n    context_window = metadata.get(\"context_window_size\", 32768)\n    prefill_chunk = metadata.get(\"prefill_chunk_size\", 2048)\n    max_seq_len = sliding_window if 
context_window == -1 else context_window\n\n    kv_cache = create_kv(\n        ShapeTuple([1]),\n        ShapeTuple([max_seq_len]),\n        ShapeTuple([prefill_chunk]),\n        ShapeTuple([16]),\n        ShapeTuple([int(sliding_window != -1)]),\n    )\n\n    nd_view = tvm.get_global_func(\"vm.builtin.reshape\")\n    add_sequence = tvm.get_global_func(\"vm.builtin.kv_state_add_sequence\")\n    begin_forward = tvm.get_global_func(\"vm.builtin.kv_state_begin_forward\")\n    end_forward = tvm.get_global_func(\"vm.builtin.kv_state_end_forward\")\n\n    tokens_tvm = tvm.runtime.tensor(input_ids, device=dev)\n    embedding = embed_func(tokens_tvm, params)\n    embedding = nd_view(embedding, ShapeTuple([1, seq_len, embedding.shape[-1]]))\n\n    add_sequence(kv_cache, 0)\n    begin_forward(kv_cache, ShapeTuple([0]), ShapeTuple([seq_len]))\n    hidden_states, _ = prefill_func(embedding, kv_cache, params)\n    end_forward(kv_cache)\n\n    # Compute logits from hidden states using embed_tokens weight (tie_word_embeddings)\n    hidden = hidden_states.numpy().astype(np.float32)\n    logits = hidden @ embed_weight.float().numpy().T\n    return logits[0, -1, :]\n\n\n@pytest.mark.skipif(_skip, reason=_skip_reason)\ndef test_mlc_hf_logit_match():\n    tokenizer = AutoTokenizer.from_pretrained(MLC_QWEN3_EMB_HF_DIR, padding_side=\"left\")\n    hf_model = AutoModel.from_pretrained(MLC_QWEN3_EMB_HF_DIR)\n    embed_weight = _load_embed_weight(MLC_QWEN3_EMB_HF_DIR)\n\n    dev = tvm.runtime.device(MLC_QWEN3_EMB_DEVICE, 0)\n    ex = tvm.runtime.load_module(MLC_QWEN3_EMB_MODEL_LIB)\n    vm = relax.VirtualMachine(ex, dev)\n    mlc_module = vm.module\n\n    metadata = json.loads(VirtualMachine(ex, tvm.runtime.device(\"cpu\"))[\"_metadata\"]())\n    params_dict, _ = tvmjs.load_tensor_cache(MLC_QWEN3_EMB_MODEL_DIR, dev)\n    param_names = [p[\"name\"] for p in metadata[\"params\"]]\n    params = [params_dict[name] for name in param_names]\n\n    for text in TEST_TEXTS:\n        hf = 
_hf_logits(text, tokenizer, hf_model, embed_weight)\n        mlc = _mlc_logits(text, tokenizer, mlc_module, params, metadata, dev, embed_weight)\n\n        cos_sim = np.dot(hf, mlc) / (np.linalg.norm(hf) * np.linalg.norm(mlc))\n        assert cos_sim > 0.99, f\"[{text[:30]}] Cosine similarity {cos_sim:.6f} below 0.99\"\n\n        max_diff = np.max(np.abs(hf - mlc))\n        assert max_diff < 1.0, f\"[{text[:30]}] Max absolute diff {max_diff:.6e} exceeds 1.0\"\n\n        hf_top10 = set(np.argsort(hf)[-10:])\n        mlc_top10 = set(np.argsort(mlc)[-10:])\n        overlap = len(hf_top10 & mlc_top10)\n        assert overlap >= 7, f\"[{text[:30]}] Top-10 overlap {overlap}/10 below 7\"\n\n\nif __name__ == \"__main__\":\n    test_mlc_hf_logit_match()\n"
  },
  {
    "path": "tests/python/op/test_batch_spec_verify.py",
    "content": "import numpy as np\nimport pytest\nimport tvm\nimport tvm.testing\n\nfrom mlc_llm.op.batch_spec_verify import batch_spec_verify\n\n# test category \"op_correctness\"\npytestmark = [pytest.mark.op_correctness]\n\n\n@pytest.mark.parametrize(\"nbatch\", [32, 64])\n@pytest.mark.parametrize(\"vocab\", [3, 32, 64, 32000, 33, 65, 32001, 128000])\n@pytest.mark.parametrize(\"plist\", [[0.5, 0.5], [1, 0], [0, 1]])\ndef test_batch_spec_verify(nbatch, vocab, plist):\n    def numpy_reference(\n        draft_probs,\n        draft_tokens,\n        model_probs,\n        token_tree_first_child,\n        token_tree_next_sibling,\n        uniform_samples,\n        token_tree_parent_ptr,\n    ):\n        nbatch = token_tree_parent_ptr.shape[0]\n        for b in range(nbatch):\n            parent_ptr = token_tree_parent_ptr[b]\n            child_ptr = token_tree_first_child[parent_ptr]\n            while child_ptr != -1:\n                child_token = draft_tokens[child_ptr]\n                p_child = model_probs[parent_ptr, child_token]\n                q_child = draft_probs[child_ptr, child_token]\n                uniform_sample = uniform_samples[child_ptr]\n                if p_child / q_child >= uniform_sample:\n                    parent_ptr = child_ptr\n                    child_ptr = token_tree_first_child[child_ptr]\n                else:\n                    model_probs[parent_ptr, :] = np.maximum(\n                        model_probs[parent_ptr, :] - draft_probs[child_ptr, :], 0.0\n                    )\n                    psum = np.sum(model_probs[parent_ptr, :])\n                    model_probs[parent_ptr, :] /= psum\n                    child_ptr = token_tree_next_sibling[child_ptr]\n            token_tree_parent_ptr[b] = parent_ptr\n\n    np.random.seed(0)\n\n    def gen_chain(num_nodes, base):\n        token_tree_first_child = list()\n        token_tree_next_sibling = list()\n        for i in range(num_nodes):\n            
token_tree_first_child.append(base + i + 1 if i + 1 < num_nodes else -1)\n            token_tree_next_sibling.append(-1)\n        return token_tree_first_child, token_tree_next_sibling, base, base + 1\n\n    def gen_full_binary_tree(height, base):\n        token_tree_first_child = list()\n        token_tree_next_sibling = list()\n        num_nodes = 2**height - 1\n        for i in range(num_nodes):\n            token_tree_first_child.append(base + i * 2 + 1 if i * 2 + 1 < num_nodes else -1)\n            token_tree_next_sibling.append(base + i * 2 + 2 if i * 2 + 2 < num_nodes else -1)\n        return token_tree_first_child, token_tree_next_sibling, base, base + 1\n\n    ### Inputs\n    num_nodes = 0\n    token_tree_first_child = list()\n    token_tree_next_sibling = list()\n    token_tree_parent_ptr = list()\n\n    for _ in range(nbatch):\n        choice = np.random.choice(2, 1, p=plist)\n        if choice == 0:\n            nodes_batch = np.random.randint(3, 32)\n            res = gen_chain(nodes_batch, num_nodes)\n            num_nodes += nodes_batch\n        else:\n            height = np.random.randint(3, 5)\n            res = gen_full_binary_tree(height, num_nodes)\n            num_nodes += 2**height - 1\n        token_tree_first_child.extend(res[0])\n        token_tree_next_sibling.extend(res[1])\n        token_tree_parent_ptr.append(res[2])\n\n    token_tree_first_child = np.array(token_tree_first_child).astype(\"int32\")\n    token_tree_next_sibling = np.array(token_tree_next_sibling).astype(\"int32\")\n    token_tree_parent_ptr = np.array(token_tree_parent_ptr).astype(\"int32\")\n\n    draft_probs = np.random.rand(num_nodes, vocab).astype(\"float32\")\n    draft_probs /= np.sum(draft_probs, axis=1, keepdims=True)\n    draft_tokens = np.random.randint(0, vocab, num_nodes).astype(\"int32\")\n    model_probs = np.random.rand(num_nodes, vocab).astype(\"float32\")\n    model_probs /= np.sum(model_probs, axis=1, keepdims=True)\n    uniform_samples = 
np.random.rand(num_nodes).astype(\"float32\")\n\n    ### TVM Inputs\n    dev = tvm.cuda(0)\n    draft_probs_tvm = tvm.runtime.tensor(draft_probs, dev)\n    draft_tokens_tvm = tvm.runtime.tensor(draft_tokens, dev)\n    model_probs_tvm = tvm.runtime.tensor(model_probs, dev)\n    token_tree_first_child_tvm = tvm.runtime.tensor(token_tree_first_child, dev)\n    token_tree_next_sibling_tvm = tvm.runtime.tensor(token_tree_next_sibling, dev)\n    uniform_samples_tvm = tvm.runtime.tensor(uniform_samples, dev)\n    token_tree_parent_ptr_tvm = tvm.runtime.tensor(token_tree_parent_ptr, dev)\n\n    # print(\"draft_probs\", draft_probs)\n    # print(\"draft_tokens\", draft_tokens)\n    # print(\"model_probs\", model_probs)\n    # print(\"token_tree_first_child\", token_tree_first_child)\n    # print(\"token_tree_next_sibling\", token_tree_next_sibling)\n    # print(\"uniform_samples\", uniform_samples)\n    # print(\"token_tree_parent_ptr\", token_tree_parent_ptr)\n\n    ### Numpy reference\n    numpy_reference(\n        draft_probs,\n        draft_tokens,\n        model_probs,\n        token_tree_first_child,\n        token_tree_next_sibling,\n        uniform_samples,\n        token_tree_parent_ptr,\n    )\n    # print(\"model_probs\", model_probs)\n    # print(\"token_tree_parent_ptr\", token_tree_parent_ptr)\n\n    ### TVM\n    kernel = batch_spec_verify(vocab)\n    mod = tvm.build(kernel, target=\"cuda\")\n    mod(\n        draft_probs_tvm,\n        draft_tokens_tvm,\n        model_probs_tvm,\n        token_tree_first_child_tvm,\n        token_tree_next_sibling_tvm,\n        uniform_samples_tvm,\n        token_tree_parent_ptr_tvm,\n    )\n    # print(\"model_probs\", model_probs_tvm.numpy())\n    # print(\"token_tree_parent_ptr\", token_tree_parent_ptr_tvm.numpy())\n\n    tvm.testing.assert_allclose(model_probs, model_probs_tvm.numpy())\n    tvm.testing.assert_allclose(\n        token_tree_parent_ptr, token_tree_parent_ptr_tvm.numpy(), rtol=0, atol=0\n    )\n\n    
time_evaluator = mod.time_evaluator(mod.entry_name, dev, number=10, repeat=3)\n    print(f\"batch_size: {nbatch}, vocab_size: {vocab}, tree_structure: {plist}\")\n    print(\n        time_evaluator(\n            draft_probs_tvm,\n            draft_tokens_tvm,\n            model_probs_tvm,\n            token_tree_first_child_tvm,\n            token_tree_next_sibling_tvm,\n            uniform_samples_tvm,\n            token_tree_parent_ptr_tvm,\n        )\n    )\n\n\nif __name__ == \"__main__\":\n    tvm.testing.main()\n"
  },
  {
    "path": "tests/python/op/test_fp8_block_matmul.py",
    "content": "from itertools import product\nfrom typing import Tuple\n\nimport ml_dtypes\nimport numpy as np\nimport pytest\nimport torch\nimport tvm\nfrom tvm import relax\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import spec\nfrom tvm.s_tir import dlight as dl\n\nfrom mlc_llm.compiler_pass.dispatch_triton_kernel import DispatchTritonKernel\nfrom mlc_llm.op import batch_matmul, cutlass, moe_matmul, triton\nfrom mlc_llm.quantization.block_scale_quantization import rowwise_group_quant_fp8\n\n# test category \"op_correctness\"\npytestmark = [pytest.mark.op_correctness]\n\nblock_size = (128, 128)\nfp8_dtype = \"float8_e4m3fn\"\n\ntorch_fp8_dtype = torch.float8_e4m3fn\ntorch_device = torch.device(\"cuda\")\n\ntorch.set_grad_enabled(False)\n\n\ndef test_fp8_block_matmul_cutlass(M: int, N: int, K: int, dtype: str):\n    class TestModule(nn.Module):\n        def __init__(self):\n            pass\n\n        def cutlass_gemm(self, x: nn.Tensor, w: nn.Tensor, w_scale: nn.Tensor):\n            n, k = w.shape\n            m = x.shape[0]\n            # assert n % block_size[0] == 0\n            assert k % block_size[1] == 0\n            assert (n + block_size[0] - 1) // block_size[0] == w_scale.shape[0]\n            assert k // block_size[1] == w_scale.shape[1]\n            assert x.shape[1] == k\n            x_fp8, x_scale = rowwise_group_quant_fp8(\n                x, block_size[1], w.dtype, transpose_scale=True\n            )\n            assert x_fp8.dtype == w.dtype\n            assert x_scale.dtype == \"float32\"\n            o = cutlass.fp8_groupwise_scaled_gemm(x_fp8, x_scale, w, w_scale, block_size, x.dtype)\n            return x_fp8, x_scale, o\n\n    mod, _, ext_mods = TestModule().export_tvm(\n        spec={\n            \"cutlass_gemm\": {\n                \"x\": spec.Tensor((\"m\", K), dtype),\n                \"w\": spec.Tensor((N, K), fp8_dtype),\n                \"w_scale\": spec.Tensor(\n                    (\n                        (N 
+ block_size[0] - 1) // block_size[0],\n                        (K + block_size[1] - 1) // block_size[1],\n                    ),\n                    \"float32\",\n                ),\n            },\n        },\n        allow_extern=True,\n    )\n    device = tvm.cuda()\n    target = tvm.target.Target.from_device(device)\n    exec = relax.build(\n        mod,\n        target=target,\n        relax_pipeline=relax.backend.cuda.get_default_pipeline(target),\n    )\n    vm = relax.VirtualMachine(exec, device)\n\n    x_torch = torch.rand(M, K, dtype=getattr(torch, dtype), device=torch_device) * 2 - 1\n    w_full_torch = torch.rand(N, K, dtype=getattr(torch, dtype), device=torch_device) * 2 - 1\n    w_torch, w_scale_torch = blockwise_quant_fp8(w_full_torch, block_size, torch_fp8_dtype)\n    x_torch, x_fp8_torch, x_scale_torch = rowwise_quant_fp8(x_torch, block_size, torch_fp8_dtype)\n    o_torch = blockwise_matmul(x_fp8_torch, x_scale_torch, w_torch, w_scale_torch, x_torch.dtype)\n    x_tvm = tvm.runtime.tensor(x_torch.view(torch.float16).cpu().numpy().view(dtype), device=device)\n    w_tvm = tvm.runtime.tensor(\n        w_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device\n    )\n    w_scale_tvm = tvm.runtime.tensor(w_scale_torch.cpu().numpy(), device=device)\n    x_fp8_tvm, x_scale_tvm, o_tvm = vm[\"cutlass_gemm\"](x_tvm, w_tvm, w_scale_tvm)\n    x_fp8_tvm = x_fp8_tvm.numpy()\n    x_scale_tvm = x_scale_tvm.numpy()\n    o_tvm = o_tvm.numpy()\n\n    np.testing.assert_allclose(\n        x_fp8_tvm,\n        x_fp8_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype),\n        atol=1e-1,\n        rtol=1e-1,\n    )\n    np.testing.assert_allclose(x_scale_tvm.T, x_scale_torch.cpu().numpy(), atol=1e-5, rtol=1e-5)\n    atol = 0.5\n    rtol = 1e-4\n    o_tvm_flat = o_tvm.flatten()\n    o_torch_flat = o_torch.view(torch.float16).cpu().numpy().view(dtype).flatten()\n    failed_indices = np.where(\n        np.abs(o_tvm_flat - o_torch_flat) > (atol + rtol * 
np.abs(o_torch_flat))\n    )[0]\n    if len(failed_indices) > 0:\n        print(f\"failed_indices: {failed_indices}, size: {len(failed_indices)}\")\n        print(f\"o_tvm_flat[failed_indices]: {o_tvm_flat[failed_indices]}\")\n        print(f\"o_torch_flat[failed_indices]: {o_torch_flat[failed_indices]}\")\n    np.testing.assert_allclose(\n        o_tvm,\n        o_torch.view(torch.float16).cpu().numpy().view(dtype),\n        atol=atol,\n        rtol=rtol,\n    )\n\n\ndef test_fp8_block_matmul_triton(M: int, N: int, K: int, dtype: str):\n    device = tvm.cuda()\n    target = tvm.target.Target.from_device(device)\n\n    class TestModule(nn.Module):\n        def __init__(self):\n            pass\n\n        def triton_gemm(self, x: nn.Tensor, w: nn.Tensor, w_scale: nn.Tensor):\n            n, k = w.shape\n            m = x.shape[0]\n            assert (n + block_size[0] - 1) // block_size[0] == w_scale.shape[0]\n            assert (k + block_size[1] - 1) // block_size[1] == w_scale.shape[1]\n            assert x.shape[1] == k\n            x_fp8, x_scale = rowwise_group_quant_fp8(\n                x, block_size[1], w.dtype, transpose_scale=False\n            )\n            assert x_fp8.dtype == w.dtype\n            assert x_scale.dtype == \"float32\"\n            o = triton.fp8_groupwise_scaled_gemm(\n                x_fp8,\n                x_scale,\n                w,\n                w_scale,\n                block_size,\n                x.dtype,\n            )\n            return x_fp8, x_scale, o\n\n    mod, _, ext_mods = TestModule().export_tvm(\n        spec={\n            \"triton_gemm\": {\n                \"x\": spec.Tensor((\"m\", K), dtype),\n                \"w\": spec.Tensor((N, K), fp8_dtype),\n                \"w_scale\": spec.Tensor(\n                    (\n                        (N + block_size[0] - 1) // block_size[0],\n                        (K + block_size[1] - 1) // block_size[1],\n                    ),\n                    \"float32\",\n        
        ),\n            },\n        },\n        allow_extern=True,\n    )\n    mod = DispatchTritonKernel(target)(mod)  # type: ignore\n    exec = relax.build(\n        mod,\n        target=target,\n        relax_pipeline=relax.backend.cuda.get_default_pipeline(target),\n    )\n    vm = relax.VirtualMachine(exec, device)\n\n    x_torch = torch.randn(M, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_full_torch = torch.randn(N, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_torch, w_scale_torch = blockwise_quant_fp8(w_full_torch, block_size, torch_fp8_dtype)\n    x_torch, x_fp8_torch, x_scale_torch = rowwise_quant_fp8(x_torch, block_size, torch_fp8_dtype)\n    o_torch = blockwise_matmul(x_fp8_torch, x_scale_torch, w_torch, w_scale_torch, x_torch.dtype)\n    x_tvm = tvm.runtime.tensor(x_torch.view(torch.float16).cpu().numpy().view(dtype), device=device)\n    w_tvm = tvm.runtime.tensor(\n        w_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device\n    )\n    w_scale_tvm = tvm.runtime.tensor(w_scale_torch.cpu().numpy(), device=device)\n    x_fp8_tvm, x_scale_tvm, o_tvm = vm[\"triton_gemm\"](x_tvm, w_tvm, w_scale_tvm)\n    x_fp8_tvm = x_fp8_tvm.numpy()\n    x_scale_tvm = x_scale_tvm.numpy()\n    o_tvm = o_tvm.numpy()\n    np.testing.assert_allclose(\n        x_fp8_tvm,\n        x_fp8_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype),\n        atol=1e-1,\n        rtol=1e-1,\n    )\n    np.testing.assert_allclose(x_scale_tvm, x_scale_torch.cpu().numpy(), atol=1e-5, rtol=1e-5)\n    atol = 0.5\n    rtol = 1e-4\n    o_tvm_flat = o_tvm.flatten()\n    o_torch_flat = o_torch.view(torch.float16).cpu().numpy().view(dtype).flatten()\n    failed_indices = np.where(\n        np.abs(o_tvm_flat - o_torch_flat) > (atol + rtol * np.abs(o_torch_flat))\n    )[0]\n    if len(failed_indices) > 0:\n        print(f\"failed_indices: {failed_indices}, size: {len(failed_indices)}\")\n        print(f\"o_tvm_flat[failed_indices]: 
{o_tvm_flat[failed_indices]}\")\n        print(f\"o_torch_flat[failed_indices]: {o_torch_flat[failed_indices]}\")\n    np.testing.assert_allclose(\n        o_tvm,\n        o_torch.view(torch.float16).cpu().numpy().view(dtype),\n        atol=atol,\n        rtol=rtol,\n    )\n\n\ndef test_fp8_block_group_matmul_cutlass(M: int, N: int, K: int, dtype: str):\n    num_experts = 256\n    top_k = 8\n\n    device = tvm.cuda()\n    target = tvm.target.Target.from_device(device)\n\n    class TestModule(nn.Module):\n        def __init__(self):\n            pass\n\n        def cutlass_group_gemm(\n            self,\n            x: nn.Tensor,\n            w: nn.Tensor,\n            w_scale: nn.Tensor,\n            indptr: nn.Tensor,\n        ):\n            e, n, k = w.shape\n            m = x.shape[0]\n            assert e == num_experts\n            assert (n + block_size[0] - 1) // block_size[0] == w_scale.shape[1]\n            assert (k + block_size[1] - 1) // block_size[1] == w_scale.shape[2]\n            assert x.shape[1] == k\n            x_fp8, x_scale = rowwise_group_quant_fp8(\n                x, block_size[1], w.dtype, transpose_scale=False\n            )\n            assert x_fp8.dtype == w.dtype\n            assert x_scale.dtype == \"float32\"\n            o = cutlass.fp8_groupwise_scaled_group_gemm(\n                x_fp8,\n                x_scale,\n                w,\n                w_scale,\n                indptr,\n                block_size,\n                x.dtype,\n            )\n            return x_fp8, x_scale, o\n\n    mod, _, ext_mods = TestModule().export_tvm(\n        spec={\n            \"cutlass_group_gemm\": {\n                \"x\": spec.Tensor((\"m\", K), dtype),\n                \"w\": spec.Tensor((num_experts, N, K), fp8_dtype),\n                \"w_scale\": spec.Tensor(\n                    (\n                        num_experts,\n                        (N + block_size[0] - 1) // block_size[0],\n                        (K + block_size[1] - 
1) // block_size[1],\n                    ),\n                    \"float32\",\n                ),\n                \"indptr\": spec.Tensor((num_experts,), \"int64\"),\n            },\n        },\n        allow_extern=True,\n    )\n    exec = relax.build(\n        mod,\n        target=target,\n        relax_pipeline=relax.backend.cuda.get_default_pipeline(target),\n    )\n    vm = relax.VirtualMachine(exec, device)\n\n    # Randomly sample `top_k` experts for each token with pytorch\n    expert_choices = torch.randint(\n        0, num_experts, (M * top_k,), device=torch_device, dtype=torch.int32\n    )\n\n    factor = 1\n    # Balance so that the number of tokens for each expert is a multiple of `factor`\n    token_balance = 0\n    num_tokens_list = [int((expert_choices == i).sum().to(\"cpu\")) for i in range(num_experts)]\n    for i in range(num_experts):\n        if token_balance > 0:\n            diff = min(token_balance, num_tokens_list[i])\n            num_tokens_list[i] -= diff\n            token_balance -= diff\n        if num_tokens_list[i] % factor != 0:\n            token_balance += factor - num_tokens_list[i] % factor\n            num_tokens_list[i] += factor - num_tokens_list[i] % factor\n    assert sum(num_tokens_list) == M * top_k\n\n    indptr = torch.zeros(num_experts + 1, device=torch_device, dtype=torch.int64)\n    for i in range(num_experts):\n        indptr[i + 1] = indptr[i] + (expert_choices == i).sum()\n    token_ids_list = []\n    for i in range(num_experts):\n        # Get the indices of the tokens that belong to the i-th expert\n        token_ids = torch.where(expert_choices == i)[0]\n        token_ids_list.append(token_ids)\n\n    x_torch = torch.randn(M * top_k, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_full_torch = torch.randn(num_experts, N, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_torch, w_scale_torch = blockwise_quant_fp8(w_full_torch, block_size, torch_fp8_dtype)\n    x_torch, x_fp8_torch, 
x_scale_torch = rowwise_quant_fp8(x_torch, block_size, torch_fp8_dtype)\n    o_torch = blockwise_group_matmul(\n        x_fp8_torch,\n        x_scale_torch,\n        w_torch,\n        w_scale_torch,\n        indptr,\n        x_torch.dtype,\n    )\n    x_tvm = tvm.runtime.tensor(x_torch.view(torch.float16).cpu().numpy().view(dtype), device=device)\n    w_tvm = tvm.runtime.tensor(\n        w_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device\n    )\n    w_scale_tvm = tvm.runtime.tensor(w_scale_torch.cpu().numpy(), device=device)\n    indptr_tvm = tvm.runtime.tensor(indptr[1:].cpu().numpy(), device=device)\n    x_fp8_tvm, x_scale_tvm, o_tvm = vm[\"cutlass_group_gemm\"](\n        x_tvm,\n        w_tvm,\n        w_scale_tvm,\n        indptr_tvm,\n    )\n    x_fp8_tvm = x_fp8_tvm.numpy()\n    x_scale_tvm = x_scale_tvm.numpy()\n    o_tvm = o_tvm.numpy()\n    np.testing.assert_allclose(\n        x_fp8_tvm,\n        x_fp8_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype),\n        atol=1e-1,\n        rtol=1e-1,\n    )\n    np.testing.assert_allclose(x_scale_tvm, x_scale_torch.cpu().numpy(), atol=1e-5, rtol=1e-5)\n    atol = 0.5\n    rtol = 1e-4\n    o_tvm_flat = o_tvm.flatten()\n    o_torch_flat = o_torch.view(torch.float16).cpu().numpy().view(dtype).flatten()\n    failed_indices = np.where(\n        np.abs(o_tvm_flat - o_torch_flat) > (atol + rtol * np.abs(o_torch_flat))\n    )[0]\n    if len(failed_indices) > 0:\n        print(f\"failed_indices: {failed_indices}, size: {len(failed_indices)}\")\n        print(f\"o_tvm_flat[failed_indices]: {o_tvm_flat[failed_indices]}\")\n        print(f\"o_torch_flat[failed_indices]: {o_torch_flat[failed_indices]}\")\n    np.testing.assert_allclose(\n        o_tvm,\n        o_torch.view(torch.float16).cpu().numpy().view(dtype),\n        atol=atol,\n        rtol=rtol,\n    )\n\n\ndef test_fp8_block_group_matmul_triton(M: int, N: int, K: int, dtype: str):\n    num_experts = 256\n    top_k = 8\n\n    device = 
tvm.cuda()\n    target = tvm.target.Target.from_device(device)\n\n    class TestModule(nn.Module):\n        def __init__(self):\n            pass\n\n        def triton_group_gemm(\n            self,\n            x: nn.Tensor,\n            w: nn.Tensor,\n            w_scale: nn.Tensor,\n            indptr: nn.Tensor,\n        ):\n            e, n, k = w.shape\n            m = x.shape[0]\n            assert e == num_experts\n            assert (n + block_size[0] - 1) // block_size[0] == w_scale.shape[1]\n            assert (k + block_size[1] - 1) // block_size[1] == w_scale.shape[2]\n            assert x.shape[1] == k\n            x_fp8, x_scale = rowwise_group_quant_fp8(\n                x, block_size[1], w.dtype, transpose_scale=False\n            )\n            assert x_fp8.dtype == w.dtype\n            assert x_scale.dtype == \"float32\"\n            o = triton.fp8_groupwise_scaled_group_gemm(\n                x_fp8,\n                x_scale,\n                w,\n                w_scale,\n                indptr,\n                block_size,\n                x.dtype,\n            )\n            return x_fp8, x_scale, o\n\n    mod, _, ext_mods = TestModule().export_tvm(\n        spec={\n            \"triton_group_gemm\": {\n                \"x\": spec.Tensor((\"m\", K), dtype),\n                \"w\": spec.Tensor((num_experts, N, K), fp8_dtype),\n                \"w_scale\": spec.Tensor(\n                    (\n                        num_experts,\n                        (N + block_size[0] - 1) // block_size[0],\n                        (K + block_size[1] - 1) // block_size[1],\n                    ),\n                    \"float32\",\n                ),\n                \"indptr\": spec.Tensor((num_experts + 1,), \"int32\"),\n            },\n        },\n        allow_extern=True,\n    )\n    mod = DispatchTritonKernel(target)(mod)  # type: ignore\n    exec = relax.build(\n        mod,\n        target=target,\n        
relax_pipeline=relax.backend.cuda.get_default_pipeline(target),\n    )\n    vm = relax.VirtualMachine(exec, device)\n\n    # Randomly sample `top_k` experts for each token with pytorch\n    expert_choices = torch.randint(\n        0, num_experts, (M * top_k,), device=torch_device, dtype=torch.int32\n    )\n\n    indptr = torch.zeros(num_experts + 1, device=torch_device, dtype=torch.int32)\n    for i in range(num_experts):\n        indptr[i + 1] = indptr[i] + (expert_choices == i).sum()\n    token_ids_list = []\n    for i in range(num_experts):\n        # Get the indices of the tokens that belong to the i-th expert\n        token_ids = torch.where(expert_choices == i)[0]\n        token_ids_list.append(token_ids)\n\n    x_torch = torch.randn(M * top_k, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_full_torch = torch.randn(num_experts, N, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_torch, w_scale_torch = blockwise_quant_fp8(w_full_torch, block_size, torch_fp8_dtype)\n    x_torch, x_fp8_torch, x_scale_torch = rowwise_quant_fp8(x_torch, block_size, torch_fp8_dtype)\n    o_torch = blockwise_group_matmul(\n        x_fp8_torch,\n        x_scale_torch,\n        w_torch,\n        w_scale_torch,\n        indptr,\n        x_torch.dtype,\n    )\n    x_tvm = tvm.runtime.tensor(x_torch.view(torch.float16).cpu().numpy().view(dtype), device=device)\n    w_tvm = tvm.runtime.tensor(\n        w_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device\n    )\n    w_scale_tvm = tvm.runtime.tensor(w_scale_torch.cpu().numpy(), device=device)\n    indptr_tvm = tvm.runtime.tensor(indptr.cpu().numpy(), device=device)\n    x_fp8_tvm, x_scale_tvm, o_tvm = vm[\"triton_group_gemm\"](\n        x_tvm,\n        w_tvm,\n        w_scale_tvm,\n        indptr_tvm,\n    )\n    x_fp8_tvm = x_fp8_tvm.numpy()\n    x_scale_tvm = x_scale_tvm.numpy()\n    o_tvm = o_tvm.numpy()\n    np.testing.assert_allclose(\n        x_fp8_tvm,\n        
x_fp8_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype),\n        atol=1e-1,\n        rtol=1e-1,\n    )\n    np.testing.assert_allclose(x_scale_tvm, x_scale_torch.cpu().numpy(), atol=1e-5, rtol=1e-5)\n    atol = 0.5\n    rtol = 1e-4\n    o_tvm_flat = o_tvm.flatten()\n    o_torch_flat = o_torch.view(torch.float16).cpu().numpy().view(dtype).flatten()\n    failed_indices = np.where(\n        np.abs(o_tvm_flat - o_torch_flat) > (atol + rtol * np.abs(o_torch_flat))\n    )[0]\n    if len(failed_indices) > 0:\n        print(f\"failed_indices: {failed_indices}, size: {len(failed_indices)}\")\n        print(f\"o_tvm_flat[failed_indices]: {o_tvm_flat[failed_indices]}\")\n        print(f\"o_torch_flat[failed_indices]: {o_torch_flat[failed_indices]}\")\n    np.testing.assert_allclose(\n        o_tvm,\n        o_torch.view(torch.float16).cpu().numpy().view(dtype),\n        atol=atol,\n        rtol=rtol,\n    )\n\n\ndef test_fp8_block_bmm_cutlass(M: int, N: int, K: int, H: int, dtype: str):\n    class TestModule(nn.Module):\n        def __init__(self):\n            pass\n\n        def cutlass_bmm(self, x: nn.Tensor, w: nn.Tensor, w_scale: nn.Tensor):\n            _, n, k = w.shape\n            assert w.shape[0] == x.shape[0] == H\n            assert n % block_size[0] == 0\n            assert k % block_size[1] == 0\n            assert n // block_size[0] == w_scale.shape[1]\n            assert k // block_size[1] == w_scale.shape[2]\n            assert x.shape[2] == k\n            o = batch_matmul.quantized_bmm(x, w, w_scale, block_size)\n            return o\n\n    mod, _, ext_mods = TestModule().export_tvm(\n        spec={\n            \"cutlass_bmm\": {\n                \"x\": spec.Tensor((H, \"m\", K), dtype),\n                \"w\": spec.Tensor((H, N, K), fp8_dtype),\n                \"w_scale\": spec.Tensor(\n                    (\n                        H,\n                        (N + block_size[0] - 1) // block_size[0],\n                        (K + block_size[1] - 1) 
// block_size[1],\n                    ),\n                    \"float32\",\n                ),\n            },\n        },\n        allow_extern=True,\n    )\n    device = tvm.cuda()\n    target = tvm.target.Target.from_device(device)\n    exec = relax.build(\n        mod,\n        target=target,\n        relax_pipeline=relax.backend.cuda.get_default_pipeline(target),\n    )\n    vm = relax.VirtualMachine(exec, device)\n\n    x_torch = torch.randn(H, M, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_full_torch = torch.randn(H, N, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_torch, w_scale_torch = blockwise_quant_fp8(w_full_torch, block_size, torch_fp8_dtype)\n    x_torch, x_fp8_torch, x_scale_torch = rowwise_quant_fp8(x_torch, block_size, torch_fp8_dtype)\n    o_torch = blockwise_bmm(x_fp8_torch, x_scale_torch, w_torch, w_scale_torch, x_torch.dtype)\n    x_tvm = tvm.runtime.tensor(x_torch.view(torch.float16).cpu().numpy().view(dtype), device=device)\n    w_tvm = tvm.runtime.tensor(\n        w_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device\n    )\n    w_scale_tvm = tvm.runtime.tensor(w_scale_torch.cpu().numpy(), device=device)\n    o_tvm = vm[\"cutlass_bmm\"](x_tvm, w_tvm, w_scale_tvm)\n    o_tvm = o_tvm.numpy()\n    atol = 0.5\n    rtol = 1e-4\n    o_tvm_flat = o_tvm.flatten()\n    o_torch_flat = o_torch.view(torch.float16).cpu().numpy().view(dtype).flatten()\n    failed_indices = np.where(\n        np.abs(o_tvm_flat - o_torch_flat) > (atol + rtol * np.abs(o_torch_flat))\n    )[0]\n    if len(failed_indices) > 0:\n        print(f\"failed_indices: {failed_indices}, size: {len(failed_indices)}\")\n        print(f\"o_tvm_flat[failed_indices]: {o_tvm_flat[failed_indices]}\")\n        print(f\"o_torch_flat[failed_indices]: {o_torch_flat[failed_indices]}\")\n    np.testing.assert_allclose(\n        o_tvm,\n        o_torch.view(torch.float16).cpu().numpy().view(dtype),\n        atol=atol,\n        rtol=rtol,\n    
)\n\n\ndef test_fp8_block_gemv_tir(N: int, K: int, up: bool, dtype: str):\n    num_experts = 256\n    top_k = 8\n    M = 1 if up else top_k\n\n    device = tvm.cuda()\n    target = tvm.target.Target.from_device(device)\n\n    class TestModule(nn.Module):\n        def __init__(self):\n            pass\n\n        def tir_moe_gemv(\n            self,\n            x: nn.Tensor,\n            w: nn.Tensor,\n            w_scale: nn.Tensor,\n            expert_indices: nn.Tensor,\n        ):\n            e, n, k = w.shape\n            m = x.shape[0]\n            assert e == num_experts\n            assert (n + block_size[0] - 1) // block_size[0] == w_scale.shape[1]\n            assert (k + block_size[1] - 1) // block_size[1] == w_scale.shape[2]\n            assert x.shape[1] == k\n            o = moe_matmul.dequantize_block_scale_float8_gemv(\n                x, w, w_scale, expert_indices, block_size, x.dtype\n            )\n            return o\n\n    mod, _, ext_mods = TestModule().export_tvm(\n        spec={\n            \"tir_moe_gemv\": {\n                \"x\": spec.Tensor((M, K), dtype),\n                \"w\": spec.Tensor((num_experts, N, K), fp8_dtype),\n                \"w_scale\": spec.Tensor(\n                    (\n                        num_experts,\n                        (N + block_size[0] - 1) // block_size[0],\n                        (K + block_size[1] - 1) // block_size[1],\n                    ),\n                    \"float32\",\n                ),\n                \"expert_indices\": spec.Tensor((1, top_k), \"int32\"),\n            },\n        },\n        allow_extern=True,\n    )\n    with target:\n        mod = dl.ApplyDefaultSchedule(\n            dl.gpu.Matmul(),\n            dl.gpu.GEMV(),\n            dl.gpu.Reduction(),\n            dl.gpu.GeneralReduction(),\n            dl.gpu.Fallback(),\n        )(mod)\n    exec = relax.build(\n        mod,\n        target=target,\n        relax_pipeline=relax.backend.cuda.get_default_pipeline(target),\n 
   )\n    vm = relax.VirtualMachine(exec, device)\n\n    # Randomly sample `top_k` experts for each token with pytorch\n    expert_choices = torch.randint(0, num_experts, (top_k,), device=torch_device, dtype=torch.int32)\n    indptr = torch.zeros(num_experts + 1, device=torch_device, dtype=torch.int32)\n    for i in range(num_experts):\n        indptr[i + 1] = indptr[i] + (expert_choices == i).sum()\n    token_ids_list = []\n    for i in range(num_experts):\n        # Get the indices of the tokens that belong to the i-th expert\n        token_ids = torch.where(expert_choices == i)[0]\n        token_ids_list.append(token_ids)\n\n    x_torch = torch.randn(M, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_full_torch = torch.randn(num_experts, N, K, dtype=getattr(torch, dtype), device=torch_device)\n    w_torch, w_scale_torch = blockwise_quant_fp8(w_full_torch, block_size, torch_fp8_dtype)\n    x_input_torch = torch.repeat_interleave(x_torch, top_k, dim=0) if up else x_torch\n    o_torch = blockwise_group_matmul_unquantized(\n        x_input_torch, w_torch, w_scale_torch, expert_choices\n    )\n    x_tvm = tvm.runtime.tensor(x_torch.view(torch.float16).cpu().numpy().view(dtype), device=device)\n    w_tvm = tvm.runtime.tensor(\n        w_torch.view(torch.uint8).cpu().numpy().view(fp8_dtype), device=device\n    )\n    w_scale_tvm = tvm.runtime.tensor(w_scale_torch.cpu().numpy(), device=device)\n    expert_choices = tvm.runtime.tensor(\n        expert_choices.reshape(1, top_k).cpu().numpy(), device=device\n    )\n    o_tvm = vm[\"tir_moe_gemv\"](x_tvm, w_tvm, w_scale_tvm, expert_choices)\n    o_tvm = o_tvm.numpy()\n    atol = 0.5\n    rtol = 1e-4\n    o_tvm_flat = o_tvm.flatten()\n    o_torch_flat = o_torch.view(torch.float16).cpu().numpy().view(dtype).flatten()\n    failed_indices = np.where(\n        np.abs(o_tvm_flat - o_torch_flat) > (atol + rtol * np.abs(o_torch_flat))\n    )[0]\n    if len(failed_indices) > 0:\n        print(f\"failed_indices: 
{failed_indices}, size: {len(failed_indices)}\")\n        print(f\"o_tvm_flat[failed_indices]: {o_tvm_flat[failed_indices]}\")\n        print(f\"o_torch_flat[failed_indices]: {o_torch_flat[failed_indices]}\")\n    np.testing.assert_allclose(\n        o_tvm,\n        o_torch.view(torch.float16).cpu().numpy().view(dtype),\n        atol=atol,\n        rtol=rtol,\n    )\n\n\ndef blockwise_matmul(\n    x_fp8_torch: torch.Tensor,\n    x_scale_torch: torch.Tensor,\n    w_torch: torch.Tensor,\n    w_scale_torch: torch.Tensor,\n    dtype,\n):\n    o_torch = torch.zeros(\n        (x_fp8_torch.shape[0], w_torch.shape[0]), dtype=dtype, device=torch_device\n    )\n    for j in range(w_scale_torch.shape[0]):\n        for k in range(w_scale_torch.shape[1]):\n            o_torch[\n                :,\n                j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[0]),\n            ] += (\n                torch.matmul(\n                    x_fp8_torch[\n                        :,\n                        k * block_size[1] : min((k + 1) * block_size[1], x_fp8_torch.shape[1]),\n                    ].to(dtype),\n                    w_torch[\n                        j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[0]),\n                        k * block_size[1] : min((k + 1) * block_size[1], w_torch.shape[1]),\n                    ].T.to(dtype),\n                )\n                * x_scale_torch[:, k : k + 1]\n                * w_scale_torch[j, k]\n            )\n    return o_torch\n\n\ndef blockwise_group_matmul(\n    x_fp8_torch: torch.Tensor,\n    x_scale_torch: torch.Tensor,\n    w_torch: torch.Tensor,\n    w_scale_torch: torch.Tensor,\n    indptr: torch.Tensor,\n    dtype,\n):\n    o_torch = torch.zeros(\n        (x_fp8_torch.shape[0], w_torch.shape[1]), dtype=dtype, device=torch_device\n    )\n    for e in range(w_scale_torch.shape[0]):\n        if indptr[e + 1] - indptr[e] == 0:\n            continue\n        indices = slice(indptr[e], indptr[e + 
1])\n        for j in range(w_scale_torch.shape[1]):\n            for k in range(w_scale_torch.shape[2]):\n                o_torch[\n                    indices,\n                    j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[1]),\n                ] += (\n                    torch.matmul(\n                        x_fp8_torch.to(dtype)[\n                            indices,\n                            k * block_size[1] : min((k + 1) * block_size[1], x_fp8_torch.shape[1]),\n                        ],\n                        w_torch[\n                            e,\n                            j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[1]),\n                            k * block_size[1] : min((k + 1) * block_size[1], w_torch.shape[2]),\n                        ].T.to(dtype),\n                    )\n                    * x_scale_torch[indices, k : k + 1]\n                    * w_scale_torch[e, j, k]\n                )\n    return o_torch\n\n\ndef blockwise_group_matmul_unquantized(\n    x_torch: torch.Tensor,\n    w_torch: torch.Tensor,\n    w_scale_torch: torch.Tensor,\n    expert_choices: torch.Tensor,\n):\n    o_torch = torch.zeros(\n        (x_torch.shape[0], w_torch.shape[1]), dtype=x_torch.dtype, device=torch_device\n    )\n    for i, e in enumerate(expert_choices):\n        for j in range(w_scale_torch.shape[1]):\n            for k in range(w_scale_torch.shape[2]):\n                o_torch[\n                    i,\n                    j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[1]),\n                ] += torch.matmul(\n                    x_torch[\n                        i,\n                        k * block_size[1] : min((k + 1) * block_size[1], x_torch.shape[1]),\n                    ],\n                    w_torch[\n                        e,\n                        j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[1]),\n                        k * block_size[1] : min((k + 1) * 
block_size[1], w_torch.shape[2]),\n                    ].T.to(x_torch.dtype)\n                    * w_scale_torch[e, j, k].to(x_torch.dtype),\n                )\n    return o_torch\n\n\ndef blockwise_bmm(\n    x_fp8_torch: torch.Tensor,\n    x_scale_torch: torch.Tensor,\n    w_torch: torch.Tensor,\n    w_scale_torch: torch.Tensor,\n    dtype,\n):\n    o_torch = torch.zeros(\n        (x_fp8_torch.shape[0], x_fp8_torch.shape[1], w_torch.shape[1]),\n        dtype=dtype,\n        device=torch_device,\n    )\n    for j in range(w_scale_torch.shape[1]):\n        for k in range(w_scale_torch.shape[2]):\n            o_torch[\n                ...,\n                j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[1]),\n            ] += (\n                torch.bmm(\n                    x_fp8_torch[\n                        ...,\n                        k * block_size[1] : min((k + 1) * block_size[1], x_fp8_torch.shape[2]),\n                    ].to(dtype),\n                    w_torch[\n                        ...,\n                        j * block_size[0] : min((j + 1) * block_size[0], w_torch.shape[1]),\n                        k * block_size[1] : min((k + 1) * block_size[1], w_torch.shape[2]),\n                    ]\n                    .transpose(1, 2)\n                    .to(dtype),\n                )\n                * x_scale_torch[..., k : k + 1]\n                * w_scale_torch[..., j : j + 1, k : k + 1]\n            )\n    return o_torch\n\n\ndef blockwise_quant_fp8(\n    w_full_torch: torch.Tensor, block_size: Tuple[int, int], quant_dtype: torch.dtype\n):\n    w_scale_shape = (\n        *w_full_torch.shape[:-2],\n        (w_full_torch.shape[-2] + block_size[0] - 1) // block_size[0],\n        (w_full_torch.shape[-1] + block_size[1] - 1) // block_size[1],\n    )\n    # For each (block_size[0], block_size[1]) block, compute the max abs value of `w_full_torch`\n    w_max_abs_torch = torch.zeros(w_scale_shape, dtype=torch.float32, device=torch_device)\n 
   for i in range(w_scale_shape[-2]):\n        for j in range(w_scale_shape[-1]):\n            w_max_abs_torch[..., i, j] = torch.max(\n                torch.abs(\n                    w_full_torch[\n                        ...,\n                        i * block_size[0] : min((i + 1) * block_size[0], w_full_torch.shape[-2]),\n                        j * block_size[1] : min((j + 1) * block_size[1], w_full_torch.shape[-1]),\n                    ]\n                ).flatten(-2, -1),\n                dim=-1,\n            )[0]\n    # Scale is the `w_max_abs_torch` divided by the max value of quant_dtype in ml_dtypes\n    fp8_max = float(ml_dtypes.finfo(fp8_dtype).max)\n    w_scale_torch = w_max_abs_torch / fp8_max\n    # `w_torch` is the `w_full_torch` divided by the `w_scale_torch` (with block awareness),\n    # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype`\n    w_torch = torch.zeros_like(w_full_torch, dtype=quant_dtype, device=torch_device)\n    if len(w_scale_shape) == 2:\n        for i in range(w_scale_shape[-2]):\n            for j in range(w_scale_shape[-1]):\n                w_torch[\n                    i * block_size[0] : min((i + 1) * block_size[0], w_full_torch.shape[-2]),\n                    j * block_size[1] : min((j + 1) * block_size[1], w_full_torch.shape[-1]),\n                ] = torch.clamp(\n                    w_full_torch[\n                        i * block_size[0] : min((i + 1) * block_size[0], w_full_torch.shape[-2]),\n                        j * block_size[1] : min((j + 1) * block_size[1], w_full_torch.shape[-1]),\n                    ]\n                    / w_scale_torch[..., i, j],\n                    -fp8_max,\n                    fp8_max,\n                )\n    else:\n        for e in range(w_scale_shape[0]):\n            for i in range(w_scale_shape[-2]):\n                for j in range(w_scale_shape[-1]):\n                    w_torch[\n                        e,\n                        i * block_size[0] : min((i + 1) * 
block_size[0], w_full_torch.shape[-2]),\n                        j * block_size[1] : min((j + 1) * block_size[1], w_full_torch.shape[-1]),\n                    ] = torch.clamp(\n                        w_full_torch[\n                            e,\n                            i\n                            * block_size[0] : min((i + 1) * block_size[0], w_full_torch.shape[-2]),\n                            j\n                            * block_size[1] : min((j + 1) * block_size[1], w_full_torch.shape[-1]),\n                        ]\n                        / w_scale_torch[e, i, j],\n                        -fp8_max,\n                        fp8_max,\n                    )\n\n    w_scale_torch = (\n        torch.rand(w_scale_torch.shape, dtype=torch.float32, device=torch_device) / fp8_max\n    )\n    return w_torch, w_scale_torch\n\n\ndef rowwise_quant_fp8(\n    x_full_torch: torch.Tensor, block_size: Tuple[int, int], quant_dtype: torch.dtype\n):\n    x_scale_shape = (\n        *x_full_torch.shape[:-1],\n        (x_full_torch.shape[-1] + block_size[1] - 1) // block_size[1],\n    )\n    # For each (block_size[1]) block, compute the max abs value of `x_full_torch`\n    x_max_abs_torch = torch.zeros(x_scale_shape, dtype=torch.float32, device=torch_device)\n    for i in range(x_scale_shape[-1]):\n        x_max_abs_torch[..., i] = torch.max(\n            torch.abs(\n                x_full_torch[\n                    ...,\n                    i * block_size[1] : min((i + 1) * block_size[1], x_full_torch.shape[-1]),\n                ]\n            ),\n            dim=-1,\n        )[0]\n    # Scale is the `x_max_abs_torch` divided by the max value of quant_dtype in ml_dtypes\n    fp8_max = float(ml_dtypes.finfo(fp8_dtype).max)\n    x_scale_torch = x_max_abs_torch / fp8_max\n    # `x_torch` is the `x_full_torch` divided by the `x_scale_torch` (with block awareness),\n    # clamped to (-fp8_max, fp8_max), and cast to `quant_dtype`\n    x_torch = 
torch.zeros_like(x_full_torch, dtype=quant_dtype, device=torch_device)\n    for i in range(x_scale_shape[-1]):\n        x_torch[\n            ...,\n            i * block_size[1] : min((i + 1) * block_size[1], x_full_torch.shape[-1]),\n        ] = torch.clamp(\n            x_full_torch[\n                ...,\n                i * block_size[1] : min((i + 1) * block_size[1], x_full_torch.shape[-1]),\n            ]\n            / x_scale_torch[..., i : i + 1],\n            -fp8_max,\n            fp8_max,\n        )\n\n    x_scale_torch = (\n        torch.rand(x_scale_torch.shape, dtype=torch.float32, device=torch_device) / fp8_max\n    )\n    for i in range(x_scale_shape[-1]):\n        x_full_torch[\n            ...,\n            i * block_size[1] : min((i + 1) * block_size[1], x_full_torch.shape[-1]),\n        ] = (\n            x_torch[\n                ...,\n                i * block_size[1] : min((i + 1) * block_size[1], x_full_torch.shape[-1]),\n            ].to(x_scale_torch.dtype)\n            * x_scale_torch[..., i : i + 1]\n        )\n    return x_full_torch, x_torch, x_scale_torch\n\n\n@pytest.mark.skip(reason=\"Test requiring SM90a\")\ndef test_cutlass_gemm():\n    # Cutlass GEMM\n    for M, (N, K), dtype in product(\n        [4, 128, 256, 1024, 2112],\n        [\n            (4608, 896),\n            (896, 2304),\n            (3072, 896),\n            (512, 896),\n            (3072, 512),\n            (4096, 512),\n            (896, 2048),\n            (129280, 896),\n        ],\n        [\"bfloat16\"],\n    ):\n        print(f\"Cutlass, M: {M}, N: {N}, K: {K}, dtype: {dtype}\")\n        test_fp8_block_matmul_cutlass(M, N, K, dtype)\n\n\n@pytest.mark.skip(reason=\"Test requiring SM90a\")\ndef test_triton_gemm():\n    # Triton GEMM\n    for M, (N, K), dtype in product(\n        [1, 128, 256, 1024, 2111],\n        [\n            (4608, 896),\n            (896, 576),\n            (896, 2304),\n        ],\n        [\"bfloat16\"],\n    ):\n        
print(f\"Triton, M: {M}, N: {N}, K: {K}, dtype: {dtype}\")\n        test_fp8_block_matmul_triton(M, N, K, dtype)\n\n\n@pytest.mark.skip(reason=\"Test requiring SM90a\")\ndef test_cutlass_group_gemm():\n    # Cutlass group GEMM\n    for M, (N, K), dtype in product(\n        [1, 128, 256, 1024, 2111],\n        [\n            (512, 896),\n            (896, 256),\n        ],\n        [\"bfloat16\"],\n    ):\n        print(f\"Cutlass group gemm, M: {M}, N: {N}, K: {K}, dtype: {dtype}\")\n        test_fp8_block_group_matmul_cutlass(M, N, K, dtype)\n\n\n@pytest.mark.skip(reason=\"Test requiring SM90a\")\ndef test_triton_group_gemm():\n    # Triton group GEMM\n    for M, (N, K), dtype in product(\n        [1, 128, 256, 1024, 2111],\n        [\n            (512, 896),\n            (896, 256),\n        ],\n        [\"bfloat16\"],\n    ):\n        print(f\"Triton group gemm, M: {M}, N: {N}, K: {K}, dtype: {dtype}\")\n        test_fp8_block_group_matmul_triton(M, N, K, dtype)\n\n\n@pytest.mark.skip(reason=\"Test requiring SM90a\")\ndef test_cutlass_bmm():\n    # Cutlass BMM\n    for M, H, (N, K), dtype in product(\n        [4, 128, 256, 1024, 2112],\n        [16, 64, 128],\n        [\n            (512, 128),\n            (128, 512),\n        ],\n        [\"bfloat16\"],\n    ):\n        print(f\"Cutlass BMM, M: {M}, N: {N}, K: {K}, H: {H}, dtype: {dtype}\")\n        test_fp8_block_bmm_cutlass(M, N, K, H, dtype)\n\n\n@pytest.mark.skip(reason=\"Test requiring SM90a\")\ndef test_tir_moe_gemv():\n    # TIR MoE GEMV\n    for (N, K), up, dtype in product(\n        [(512, 896), (896, 256)],\n        [True, False],\n        [\"bfloat16\"],\n    ):\n        print(f\"TIR MoE GEMV, N: {N}, K: {K}, up: {up}, dtype: {dtype}\")\n        test_fp8_block_gemv_tir(N, K, up, dtype)\n\n\nif __name__ == \"__main__\":\n    test_cutlass_gemm()\n    test_triton_gemm()\n    test_cutlass_group_gemm()\n    test_triton_group_gemm()\n    test_cutlass_bmm()\n    test_tir_moe_gemv()\n"
  },
  {
    "path": "tests/python/op/test_mrope.py",
    "content": "import numpy as np\nimport pytest\n\ntvm = pytest.importorskip(\"tvm\")\nfrom tvm import relax\nfrom tvm.relax.frontend import nn\nfrom tvm.relax.frontend.nn import spec\nfrom tvm.runtime import tensor as tvm_tensor\n\nfrom mlc_llm.op import (\n    MultimodalRotaryEmbedding,\n    VisionPositionMetadata,\n    apply_multimodal_rotary_pos_emb,\n    get_mrope_position_ids,\n)\n\n\ndef _numpy_rotate_half(x: np.ndarray) -> np.ndarray:\n    x1, x2 = np.split(x, 2, axis=-1)\n    return np.concatenate([-x2, x1], axis=-1)\n\n\ndef _numpy_apply_mrope(\n    q: np.ndarray,\n    k: np.ndarray,\n    position_ids: np.ndarray,\n    theta: float,\n    mrope_section: tuple[int, ...],\n) -> tuple[np.ndarray, np.ndarray]:\n    if position_ids.ndim != 3:\n        raise ValueError(f\"position_ids must be rank-3, got shape {position_ids.shape}\")\n    if position_ids.shape[0] == 3:\n        position_ids = np.transpose(position_ids, (1, 2, 0))\n    elif position_ids.shape[-1] != 3:\n        raise ValueError(\n            \"position_ids must have shape (batch, seq, 3) or (3, batch, seq), \"\n            f\"got {position_ids.shape}\"\n        )\n\n    head_dim = q.shape[-1]\n    inv_freq = 1.0 / (theta ** (np.arange(0, head_dim, 2, dtype=np.float32) / float(head_dim)))\n    pos = np.transpose(position_ids, (2, 0, 1))\n    inv = inv_freq.reshape(1, 1, -1, 1).astype(np.float32)\n    inv = np.broadcast_to(inv, (3, pos.shape[1], inv_freq.size, 1))\n    pos = pos.reshape(3, pos.shape[1], 1, pos.shape[2]).astype(np.float32)\n    freqs = np.matmul(inv, pos)\n    freqs = np.transpose(freqs, (0, 1, 3, 2))\n    emb = np.concatenate([freqs, freqs], axis=-1)\n    cos = np.cos(emb)\n    sin = np.sin(emb)\n    split_sizes = list(mrope_section) * 2\n    split_points = np.cumsum(split_sizes)[:-1]\n    cos_chunks = np.split(cos, split_points, axis=-1)\n    sin_chunks = np.split(sin, split_points, axis=-1)\n    cos = np.concatenate([chunk[idx % 3] for idx, chunk in enumerate(cos_chunks)], 
axis=-1)\n    sin = np.concatenate([chunk[idx % 3] for idx, chunk in enumerate(sin_chunks)], axis=-1)\n    cos = np.expand_dims(cos, axis=2)\n    sin = np.expand_dims(sin, axis=2)\n    q_out = q * cos + _numpy_rotate_half(q) * sin\n    k_out = k * cos + _numpy_rotate_half(k) * sin\n    return q_out, k_out\n\n\ndef _evaluate_tensor(expr):\n    mod = tvm.IRModule.from_expr(expr)\n    target = tvm.target.Target(\"llvm\")\n    ex = tvm.relax.build(mod, target)\n    vm = tvm.relax.VirtualMachine(ex, tvm.cpu())\n    return vm[\"main\"]().numpy()\n\n\ndef _run_mlc_mrope(\n    q_np: np.ndarray,\n    k_np: np.ndarray,\n    position_ids_np: np.ndarray,\n    theta: float,\n    mrope_section: tuple[int, ...],\n) -> tuple[np.ndarray, np.ndarray]:\n    class RopeModule(nn.Module):  # pylint: disable=too-few-public-methods\n        def __init__(self):\n            super().__init__()\n            self.rotary = MultimodalRotaryEmbedding(q_np.shape[-1], theta, mrope_section)\n\n        def forward(\n            self,\n            q: nn.Tensor,\n            k: nn.Tensor,\n            pos: nn.Tensor,\n        ):\n            \"\"\"Run MRoPE on test tensors and return rotated query/key outputs.\"\"\"\n            cos, sin = self.rotary(q, pos)\n            return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)\n\n    module = RopeModule()\n    mod, _, _ = module.export_tvm(\n        spec={\n            \"forward\": {\n                \"q\": spec.Tensor(q_np.shape, \"float32\"),\n                \"k\": spec.Tensor(k_np.shape, \"float32\"),\n                \"pos\": spec.Tensor(position_ids_np.shape, \"int64\"),\n            }\n        },\n        allow_extern=True,\n    )\n    target = tvm.target.Target(\"llvm\")\n    exec_mod = relax.build(mod, target=target)\n    vm = relax.VirtualMachine(exec_mod, tvm.cpu())\n    device = tvm.cpu()\n    q_nd = tvm_tensor(q_np.astype(\"float32\"), device=device)\n    k_nd = tvm_tensor(k_np.astype(\"float32\"), device=device)\n    pos_nd 
= tvm_tensor(position_ids_np.astype(\"int64\"), device=device)\n    out_q, out_k = vm[\"forward\"](q_nd, k_nd, pos_nd)\n    return out_q.numpy(), out_k.numpy()\n\n\ndef test_apply_mrope_matches_numpy_reference():\n    theta = 10000.0\n    mrope_section = (2, 2, 2)\n    batch, seq_len, heads, head_dim = 1, 4, 2, 12\n    rng = np.random.default_rng(0)\n    q_np = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)\n    k_np = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)\n    position_ids = np.zeros((batch, seq_len, 3), dtype=np.int64)\n    position_ids[0, :, 0] = np.arange(seq_len)\n    position_ids[0, :, 1] = np.arange(seq_len) * 2\n    position_ids[0, :, 2] = np.arange(seq_len) * 3\n\n    mlc_q, mlc_k = _run_mlc_mrope(q_np, k_np, position_ids, theta, mrope_section)\n    ref_q, ref_k = _numpy_apply_mrope(q_np, k_np, position_ids, theta, mrope_section)\n\n    np.testing.assert_allclose(mlc_q, ref_q, rtol=1e-5, atol=1e-5)\n    np.testing.assert_allclose(mlc_k, ref_k, rtol=1e-5, atol=1e-5)\n\n\ndef test_get_mrope_position_ids_text_only():\n    input_ids = np.array([[1, 2, 3, 0, 0]], dtype=np.int64)\n    attention_mask = np.array([[1, 1, 1, 0, 0]], dtype=np.int64)\n    meta = VisionPositionMetadata(\n        vision_start_token_id=1000,\n        image_token_id=1001,\n        video_token_id=1002,\n        spatial_merge_size=2,\n        tokens_per_second=4.0,\n    )\n    position_ids, deltas = get_mrope_position_ids(\n        input_ids,\n        meta,\n        attention_mask=attention_mask,\n        image_grid_thw=None,\n        video_grid_thw=None,\n        second_per_grid_ts=None,\n    )\n    expected = attention_mask.cumsum(axis=-1) - 1\n    expected = np.where(attention_mask == 0, 1, expected)\n    expected = np.expand_dims(expected, axis=0).repeat(3, axis=0)\n    np.testing.assert_array_equal(position_ids, expected)\n    np.testing.assert_array_equal(deltas, np.array([[-2]], dtype=np.int64))\n\n\ndef 
test_get_mrope_position_ids_single_image_block():\n    meta = VisionPositionMetadata(\n        vision_start_token_id=5000,\n        image_token_id=5001,\n        video_token_id=6000,\n        spatial_merge_size=2,\n        tokens_per_second=4.0,\n    )\n    input_ids = np.array(\n        [[11, 12, 5000, 5001, 21, 22, 23, 24, 31, 32]],\n        dtype=np.int64,\n    )\n    attention_mask = np.ones_like(input_ids, dtype=np.int64)\n    image_grid_thw = np.array([[1, 4, 4]], dtype=np.int64)\n    position_ids, deltas = get_mrope_position_ids(\n        input_ids,\n        meta,\n        attention_mask=attention_mask,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=None,\n        second_per_grid_ts=None,\n    )\n    expected = np.array(\n        [\n            [0, 1, 2, 3, 3, 3, 3, 5, 6, 7],\n            [0, 1, 2, 3, 3, 4, 4, 5, 6, 7],\n            [0, 1, 2, 3, 4, 3, 4, 5, 6, 7],\n        ],\n        dtype=np.int64,\n    ).reshape(3, 1, -1)\n    np.testing.assert_array_equal(position_ids, expected)\n    np.testing.assert_array_equal(deltas, np.array([[-2]], dtype=np.int64))\n\n\ndef test_apply_mrope_accepts_3_batch_seq_layout():\n    theta = 10000.0\n    mrope_section = (2, 2, 2)\n    batch, seq_len, heads, head_dim = 1, 4, 2, 12\n    rng = np.random.default_rng(1)\n    q_np = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)\n    k_np = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)\n\n    position_ids_bsc = np.zeros((batch, seq_len, 3), dtype=np.int64)\n    position_ids_bsc[0, :, 0] = np.arange(seq_len)\n    position_ids_bsc[0, :, 1] = np.arange(seq_len) * 2\n    position_ids_bsc[0, :, 2] = np.arange(seq_len) * 3\n    position_ids_3bs = np.transpose(position_ids_bsc, (2, 0, 1))\n\n    mlc_q_bsc, mlc_k_bsc = _run_mlc_mrope(q_np, k_np, position_ids_bsc, theta, mrope_section)\n    mlc_q_3bs, mlc_k_3bs = _run_mlc_mrope(q_np, k_np, position_ids_3bs, theta, mrope_section)\n    ref_q, ref_k = 
_numpy_apply_mrope(q_np, k_np, position_ids_bsc, theta, mrope_section)\n\n    np.testing.assert_allclose(mlc_q_bsc, ref_q, rtol=1e-5, atol=1e-5)\n    np.testing.assert_allclose(mlc_k_bsc, ref_k, rtol=1e-5, atol=1e-5)\n    np.testing.assert_allclose(mlc_q_3bs, ref_q, rtol=1e-5, atol=1e-5)\n    np.testing.assert_allclose(mlc_k_3bs, ref_k, rtol=1e-5, atol=1e-5)\n\n\ndef test_get_mrope_position_ids_output_is_directly_usable():\n    theta = 10000.0\n    mrope_section = (2, 2, 2)\n    meta = VisionPositionMetadata(\n        vision_start_token_id=7000,\n        image_token_id=7001,\n        video_token_id=7002,\n        spatial_merge_size=2,\n        tokens_per_second=4.0,\n    )\n    input_ids = np.array([[11, 12, 7000, 7001, 21, 22, 23, 24, 31, 32]], dtype=np.int64)\n    attention_mask = np.ones_like(input_ids, dtype=np.int64)\n    image_grid_thw = np.array([[1, 4, 4]], dtype=np.int64)\n    position_ids_3bs, _ = get_mrope_position_ids(\n        input_ids,\n        meta,\n        attention_mask=attention_mask,\n        image_grid_thw=image_grid_thw,\n        video_grid_thw=None,\n        second_per_grid_ts=None,\n    )\n    position_ids_bsc = np.transpose(position_ids_3bs, (1, 2, 0))\n\n    batch, seq_len = input_ids.shape\n    heads, head_dim = 2, 12\n    rng = np.random.default_rng(2)\n    q_np = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)\n    k_np = rng.standard_normal((batch, seq_len, heads, head_dim), dtype=np.float32)\n\n    mlc_q_3bs, mlc_k_3bs = _run_mlc_mrope(q_np, k_np, position_ids_3bs, theta, mrope_section)\n    mlc_q_bsc, mlc_k_bsc = _run_mlc_mrope(q_np, k_np, position_ids_bsc, theta, mrope_section)\n\n    np.testing.assert_allclose(mlc_q_3bs, mlc_q_bsc, rtol=1e-5, atol=1e-5)\n    np.testing.assert_allclose(mlc_k_3bs, mlc_k_bsc, rtol=1e-5, atol=1e-5)\n"
  },
  {
    "path": "tests/python/op/test_top_p_pivot.py",
    "content": "import numpy as np\nimport pytest\nimport tvm\nimport tvm.testing\n\nfrom mlc_llm.op.top_p_pivot import top_p_pivot, top_p_renorm\n\n# mypy: disable-error-code=\"var-annotated\"\n\n# test category \"op_correctness\"\npytestmark = [pytest.mark.op_correctness]\n\n\n@pytest.mark.parametrize(\"batch_size\", [32, 64])\n@pytest.mark.parametrize(\"vocab\", [3, 32, 64, 128])\ndef test_top_p_renorm(batch_size, vocab):\n    top_p = 0.95\n    init_pivots_np = np.array([1 - top_p, 0.02, 0.01]).astype(np.float32)\n    top_p_np = np.array([top_p]).astype(np.float32)\n\n    p_np = np.random.exponential(3, size=(batch_size, vocab)).astype(np.float32)\n    p_np /= np.sum(p_np, axis=-1, keepdims=True)\n    final_pivot_np = np.zeros(batch_size).astype(np.float32)\n    final_lsum_np = np.zeros(batch_size).astype(np.float32)\n\n    dev = tvm.cuda(0)\n    var_prob = tvm.runtime.tensor(p_np, dev)\n    var_init_pivots = tvm.runtime.tensor(init_pivots_np, dev)\n    top_p_global = tvm.runtime.tensor(top_p_np, dev)\n    var_final_pivot = tvm.runtime.tensor(final_pivot_np, dev)\n    var_final_lsum = tvm.runtime.tensor(final_lsum_np, dev)\n\n    kernel = top_p_pivot(init_pivots_np.shape[0])\n    mod = tvm.build(kernel, target=\"cuda\")\n    mod(var_prob, top_p_global, var_init_pivots, var_final_pivot, var_final_lsum)\n\n    final_pivot = var_final_pivot.numpy()\n    final_lsum = var_final_lsum.numpy()\n\n    renorm_np = p_np.copy()\n    var_renorm = tvm.runtime.tensor(renorm_np, dev)\n\n    kernel_renorm = top_p_renorm()\n    mod_renorm = tvm.build(kernel_renorm, target=\"cuda\")\n    mod_renorm(var_prob, var_final_pivot, var_final_lsum, var_renorm)\n\n    renorm = var_renorm.numpy()\n\n    def verify_pivot(probs: np.ndarray, pivot: float, lsum: float, renorm: np.ndarray):\n        sorted_probs = np.sort(probs, axis=-1)[::-1]\n        num_larger_than_pivot = np.sum(sorted_probs >= pivot)\n        filtered_sorted_probs = sorted_probs[:num_larger_than_pivot]\n        min_larger_than_pivot = min(filtered_sorted_probs)\n\n        sum_larger_than_pivot = np.sum(np.where(sorted_probs >= pivot, sorted_probs, 0))\n        sum_larger_than_pivot_exclude_min = np.sum(\n            np.where(filtered_sorted_probs != min_larger_than_pivot, filtered_sorted_probs, 0)\n        )\n\n        probs[probs < pivot] = 0\n        renorm_prob = probs / np.sum(probs, axis=-1, keepdims=True)\n        try:\n            assert sum_larger_than_pivot >= top_p\n            assert sum_larger_than_pivot_exclude_min < top_p\n            assert abs(lsum - sum_larger_than_pivot) < 1e-6\n            assert np.allclose(renorm, renorm_prob, atol=1e-6, rtol=1e-6)\n        except AssertionError:\n            print(\"Failed\")\n            print(\"probs:\", repr(probs))\n            print(\"pivot:\", pivot)\n            print(\"sorted_probs:\", sorted_probs)\n            print(\"num_larger_than_pivot:\", num_larger_than_pivot)\n            print(\"filtered_sorted_probs:\", filtered_sorted_probs)\n            print(\"min_larger_than_pivot:\", min_larger_than_pivot)\n            print(\"sum_larger_than_pivot:\", sum_larger_than_pivot)\n            print(\"sum_larger_than_pivot_exclude_min:\", sum_larger_than_pivot_exclude_min)\n            print(\"renorm_prob:\", renorm_prob)\n            print(\"renorm:\", renorm)\n            raise\n\n    for i in range(batch_size):\n        verify_pivot(p_np[i], final_pivot[i], final_lsum[i], renorm[i])\n\n\nif __name__ == \"__main__\":\n    tvm.testing.main()\n"
  },
  {
    "path": "tests/python/op/test_tree_attn.py",
    "content": "import math\n\nimport numpy as np\nimport pytest\nimport tvm\nimport tvm.testing\nfrom tvm.relax.frontend.nn.llm import tree_attn\n\n# test category \"op_correctness\"\npytestmark = [pytest.mark.op_correctness]\n\n\n@pytest.mark.parametrize(\"nbatch\", [1, 4, 32])\n@pytest.mark.parametrize(\"h_q\", [8, 16])\n@pytest.mark.parametrize(\"h_kv\", [4, 8])\n@pytest.mark.parametrize(\"d\", [128])\n@pytest.mark.parametrize(\"rotary_mode\", [0, 1])\ndef test_tree_attn(nbatch, h_q, h_kv, d, rotary_mode):\n    np.random.seed(0)\n    np.set_printoptions(linewidth=10000)\n\n    def gen_chain(num_nodes):\n        mask = np.tril(np.ones((num_nodes, num_nodes)))\n        return num_nodes, list(mask.flatten()), np.arange(num_nodes)\n\n    def gen_full_binary_tree(height):\n        mask = list()\n        pos = list()\n        num_nodes = 2**height - 1\n        for i in range(num_nodes):\n            if i == 0:\n                mask_0 = [0] * num_nodes\n                mask_0[0] = 1\n                mask.append(mask_0)\n                pos.append(0)\n            else:\n                mask_i = mask[(i + 1) // 2 - 1].copy()\n                mask_i[i] = 1\n                mask.append(mask_i)\n                pos.append(pos[(i + 1) // 2 - 1] + 1)\n        return num_nodes, list(np.array(mask).flatten()), pos\n\n    ### Inputs\n    num_nodes = 0\n    m_list = list()\n    mn_list = list()\n    mask_list = list()\n    q_pos_list = list()\n\n    mn_list.append(0)\n\n    for _ in range(nbatch):\n        choice = np.random.choice(2, 1, p=[1, 0])\n        if choice == 0:\n            nodes_batch = np.random.randint(3, 32)\n            res = gen_chain(nodes_batch)\n            num_nodes += nodes_batch\n        else:\n            height = np.random.randint(2, 6)\n            res = gen_full_binary_tree(height)\n            num_nodes += 2**height - 1\n        m_list.append(res[0])\n        mn_list.append(res[0] ** 2)\n        mask_list.extend(res[1])\n        
q_pos_list.extend(res[2])\n\n    qkv_indptr = np.array(np.cumsum([0] + m_list)).astype(np.int32)\n    m_list = np.array(m_list).astype(np.int32)\n    mn_list = np.array(mn_list).astype(np.int32)\n    mn_list = np.cumsum(mn_list).astype(np.int32)\n    mask_list = np.array(mask_list).astype(np.int32)\n    q_pos_list = np.array(q_pos_list).astype(np.int32)\n\n    # print(\"qkv_indptr:\", qkv_indptr)\n    # print(\"m_list:\", m_list)\n    # print(\"mn_list:\", mn_list)\n    # for num_nodes, base in zip(m_list, mn_list):\n    #     print(\"num_nodes:\", num_nodes)\n    #     print(\"indptr:\", base)\n    #     print(\n    #         \"mask:\",\n    #         mask_list[base : base + num_nodes * num_nodes].reshape(num_nodes, num_nodes),\n    #     )\n    #     print(\"q_pos:\", q_pos_list[base : base + num_nodes])\n\n    q = np.random.rand(num_nodes, h_q, d).astype(np.float16)\n    q_indptr = qkv_indptr\n    k = np.random.rand(num_nodes, h_kv, d).astype(np.float16)\n    v = np.random.rand(num_nodes, h_kv, d).astype(np.float16)\n    kv_indptr = qkv_indptr\n    q_rope_position = q_pos_list\n    m_arr = m_list\n    mn_indptr = mn_list\n    mask = mask_list\n    output = np.zeros((num_nodes, h_q, d), dtype=np.float16)\n    lse = np.zeros((num_nodes, h_q), dtype=np.float32)\n    rotary_scale = 1.0\n    rotary_theta = 10000.0\n    attn_score_scaling_factor = 1.0\n\n    ### TVM Inputs\n    dev = tvm.cuda(0)\n    q_tvm = tvm.runtime.tensor(q, dev)\n    q_indptr_tvm = tvm.runtime.tensor(q_indptr, dev)\n    k_tvm = tvm.runtime.tensor(k, dev)\n    v_tvm = tvm.runtime.tensor(v, dev)\n    kv_indptr_tvm = tvm.runtime.tensor(kv_indptr, dev)\n    q_rope_position_tvm = tvm.runtime.tensor(q_rope_position, dev)\n    # m_arr_tvm = tvm.runtime.tensor(m_arr, dev)\n    mn_indptr_tvm = tvm.runtime.tensor(mn_indptr, dev)\n    mask_tvm = tvm.runtime.tensor(mask, dev)\n    output_tvm = tvm.runtime.tensor(output, dev)\n    lse_tvm = tvm.runtime.tensor(lse, dev)\n\n    target = 
tvm.target.Target(\"cuda\")\n    kernel = tree_attn(h_kv=h_kv, h_q=h_q, d=d, dtype=\"float16\", rope_scaling={}, target=target)\n    mod = tvm.build(kernel, target=target)\n    mod(\n        q_tvm,\n        q_indptr_tvm,\n        k_tvm,\n        v_tvm,\n        kv_indptr_tvm,\n        q_rope_position_tvm,\n        # m_arr_tvm,\n        mn_indptr_tvm,\n        mask_tvm,\n        output_tvm,\n        lse_tvm,\n        rotary_mode,\n        rotary_scale,\n        rotary_theta,\n        attn_score_scaling_factor,\n        nbatch,\n    )\n\n    ### Numpy reference\n    def numpy_reference(\n        q,\n        q_indptr,\n        k,\n        v,\n        kv_indptr,\n        q_rope_position,\n        m_arr,\n        mn_indptr,\n        mask,\n        rotary_mode,\n        rotary_scale,\n        rotary_theta,\n        attn_score_scaling_factor,\n        output_tvm,\n    ):\n        def rope_freq(s, d, d_range, theta, dtype):\n            freq = s / math.pow(theta, (d * 2 % d_range) / float(d_range))\n            cos_freq = np.cos(freq).astype(dtype)\n            sin_freq = np.sin(freq).astype(dtype)\n            return cos_freq, sin_freq\n\n        def rope(buffer, offset, rotary_dim, theta, scale, dtype):\n            result = buffer.copy()\n            for l, h, d in np.ndindex(buffer.shape):\n                cos_freq, sin_freq = rope_freq(offset[l] * scale, d, rotary_dim, theta, dtype)\n                cos = cos_freq * buffer[l, h, d]\n                sin = sin_freq * (\n                    -buffer[l, h, d + rotary_dim // 2]\n                    if d < rotary_dim // 2\n                    else buffer[l, h, d - rotary_dim // 2]\n                )\n                result[l, h, d] = cos + sin\n            return result\n\n        for i in range(len(m_arr)):\n            num_nodes = m_arr[i]\n            base = mn_indptr[i]\n            q_base = q_indptr[i]\n            kv_base = kv_indptr[i]\n            q_pos = q_rope_position[q_base : q_base + num_nodes]  # (num_nodes,)\n 
           q_i = q[q_base : q_base + num_nodes]  # (num_nodes, h_q, d)\n            k_i = k[kv_base : kv_base + num_nodes]  # (num_nodes, h_kv, d)\n            v_i = v[kv_base : kv_base + num_nodes]  # (num_nodes, h_kv, d)\n            mask_i = mask[base : base + num_nodes * num_nodes].reshape(num_nodes, num_nodes)\n\n            if rotary_mode == 1:\n                q_i = rope(q_i, q_pos, d, rotary_theta, rotary_scale, q_i.dtype)\n                k_i = rope(k_i, q_pos, d, rotary_theta, rotary_scale, k_i.dtype)\n\n            # group attention\n            # q: (num_nodes, h_q, d)\n            # k: (num_nodes, h_kv, d)\n            # v: (num_nodes, h_kv, d)\n            group_size = h_q // h_kv\n            q_reshape = q_i.transpose(1, 0, 2)  # (h_q, num_nodes, d)\n            k_reshape = k_i.transpose(1, 2, 0)  # (h_kv, d, num_nodes)\n            v_reshape = v_i.transpose(1, 0, 2)  # (h_kv, num_nodes, d)\n            # expand k_reshape\n            k_reshape = k_reshape.reshape(h_kv, 1, d, num_nodes)\n            k_reshape = np.repeat(k_reshape, group_size, axis=1)\n            k_reshape = k_reshape.reshape(h_q, d, num_nodes)\n            # expand v_reshape\n            v_reshape = v_reshape.reshape(h_kv, 1, num_nodes, d)\n            v_reshape = np.repeat(v_reshape, group_size, axis=1)\n            v_reshape = v_reshape.reshape(h_q, num_nodes, d)\n            # print(\"q_reshape:\", q_reshape.shape)\n            # print(\"k_reshape:\", k_reshape.shape)\n            # print(\"v_reshape:\", v_reshape.shape)\n\n            # qk: (h_q, num_nodes, num_nodes)\n            qk = np.matmul(q_reshape, k_reshape) * attn_score_scaling_factor / math.sqrt(float(d))\n            # softmax(qk, axis=-1), numerical stability\n            qk[:, mask_i == 0] = -np.inf\n            qk_max = np.max(qk, axis=-1, keepdims=True)\n            qk = np.exp(qk - qk_max)\n            qk = qk / np.sum(qk, axis=-1, keepdims=True)\n\n            # attention\n            output_i = np.matmul(qk, 
v_reshape).transpose(1, 0, 2)  # (num_nodes, h_q, d)\n            # print(output_i)\n\n            tvm.testing.assert_allclose(\n                output_i, output_tvm[q_base : q_base + num_nodes], rtol=1e-3, atol=1e-3\n            )\n\n    numpy_reference(\n        q,\n        q_indptr,\n        k,\n        v,\n        kv_indptr,\n        q_rope_position,\n        m_arr,\n        mn_indptr,\n        mask,\n        rotary_mode,\n        rotary_scale,\n        rotary_theta,\n        attn_score_scaling_factor,\n        output_tvm.numpy(),\n    )\n\n\nif __name__ == \"__main__\":\n    tvm.testing.main()\n"
  },
  {
    "path": "tests/python/op/test_two_stage_softmax.py",
    "content": "import numpy as np\nimport pytest\nimport scipy.special\nimport tvm\nfrom tvm.s_tir import dlight\n\n# test category \"op_correctness\"\npytestmark = [pytest.mark.op_correctness]\n\n\ndef test_two_stage_softmax():\n    from mlc_llm.compiler_pass.rewrite_softmax import _get_lse_and_softmax_func\n\n    chunk_size = 4096\n    target = tvm.target.Target(\"cuda\")\n    f_chunk_lse, f_softmax_with_lse = _get_lse_and_softmax_func(target, chunk_size)\n    mod = tvm.IRModule({\"chunk_lse\": f_chunk_lse, \"softmax_with_chunked_lse\": f_softmax_with_lse})\n    with target:\n        mod = dlight.ApplyDefaultSchedule(dlight.gpu.GeneralReduction())(mod)\n\n    runtime_mod = tvm.build(mod, target=target)\n    device = tvm.cuda()\n\n    num_runs = 5\n    vocab_size = 128256\n    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:\n        for _ in range(num_runs):\n            x_np = np.random.uniform(low=-10, high=10, size=(batch_size, vocab_size)).astype(\n                \"float32\"\n            )\n            y_np = scipy.special.softmax(x_np, axis=-1)\n\n            x_nd = tvm.runtime.tensor(x_np, device=device)\n            r_nd = tvm.runtime.empty(\n                (batch_size, (vocab_size + chunk_size - 1) // chunk_size),\n                x_np.dtype,\n                device=device,\n            )\n            y_nd = tvm.runtime.empty(x_np.shape, x_np.dtype, device=device)\n\n            runtime_mod[\"chunk_lse\"](x_nd, r_nd)\n            runtime_mod[\"softmax_with_chunked_lse\"](x_nd, r_nd, y_nd)\n\n            y_nd_arr = y_nd.numpy()\n            np.testing.assert_allclose(y_nd_arr, y_np, atol=1e-6, rtol=1e-6)\n\n        print(f\"pass batch size {batch_size}\")\n\n\nif __name__ == \"__main__\":\n    test_two_stage_softmax()\n"
  },
  {
    "path": "tests/python/quantization/test_awq_quantization.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nfrom typing import List\n\nimport numpy as np\nimport pytest\nimport torch\nimport tvm\nimport tvm.testing\nfrom tvm import DataType\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.loader import QuantizeMapping\nfrom mlc_llm.quantization import QUANTIZATION, AWQQuantize\n\n\ndef dequantize_np(\n    config: AWQQuantize,\n    weight: np.ndarray,\n    zeros: np.ndarray,\n    scale: np.ndarray,\n) -> np.ndarray:\n    def decode_int_arr(int_arr: np.ndarray, num_elem_per_storage: int, bits: int):\n        bin_mask = (1 << bits) - 1\n        int_arr_repeated = np.repeat(int_arr, num_elem_per_storage, axis=-1)\n        indice_j = np.indices(int_arr_repeated.shape)[1]\n        arr_bin = np.bitwise_and(\n            np.right_shift(\n                int_arr_repeated,\n                (indice_j % num_elem_per_storage) * bits,\n            ),\n            bin_mask,\n        )\n        return arr_bin\n\n    weight_bin = decode_int_arr(\n        weight, config.num_elem_per_storage, DataType(config.quantize_dtype).bits\n    )\n    zero_bin = decode_int_arr(\n        zeros, config.num_elem_per_storage, DataType(config.quantize_dtype).bits\n    )\n    scale_repeated = np.repeat(scale, config.group_size, axis=-1)\n    zero_bin_repeated = np.repeat(zero_bin, config.group_size, axis=-1)\n    return (weight_bin - zero_bin_repeated) * scale_repeated\n\n\n@pytest.mark.parametrize(\n    \"quant_name, shape, dtype\",\n    [\n        (\"q4f16_awq\", [2, 4096], \"float16\"),\n    ],\n)\ndef test_dequantize_weight(quant_name: str, shape: List[int], dtype: str):\n    class Test(nn.Module):\n        def __init__(self) -> None:\n            super().__init__()\n            self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype)\n\n        def forward(self, x: nn.Tensor):\n            return self.linear(x)\n\n    config = QUANTIZATION[quant_name]\n    assert isinstance(config, AWQQuantize)\n    weight_np = 
np.random.randint(\n        np.iinfo(config.storage_dtype).min,\n        np.iinfo(config.storage_dtype).max,\n        (shape[0], shape[1] // config.num_elem_per_storage),\n    ).astype(config.storage_dtype)\n    zeros_np = np.random.randint(\n        np.iinfo(config.storage_dtype).min,\n        np.iinfo(config.storage_dtype).max,\n        (shape[0], shape[1] // config.num_elem_per_storage // config.group_size),\n    ).astype(config.storage_dtype)\n    scale_np = np.random.random((shape[0], shape[1] // config.group_size)).astype(\n        config.model_dtype\n    )\n    mod = config.quantize_model(Test(), QuantizeMapping({}, {}), \"\")\n    mod.linear.qweight.data = weight_np\n    mod.linear.qzeros.data = zeros_np\n    mod.linear.scales.data = scale_np\n    model = mod.jit(spec={\"forward\": {\"x\": nn.spec.Tensor((shape[1], shape[1]), dtype)}})\n    out = model[\"forward\"](\n        torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype)))  # pylint: disable=no-member\n    )\n    ref = dequantize_np(config, weight_np, zeros_np, scale_np).T\n    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)\n\n\nif __name__ == \"__main__\":\n    test_dequantize_weight(\"q4f16_awq\", [2, 4096], \"float16\")\n"
  },
  {
    "path": "tests/python/quantization/test_group_quantization.py",
    "content": "# pylint: disable=invalid-name,missing-docstring\nfrom typing import List\n\nimport numpy as np\nimport pytest\nimport torch\nimport tvm\nimport tvm.testing\nfrom tvm import DataType\nfrom tvm.relax.frontend import nn\n\nfrom mlc_llm.loader import QuantizeMapping\nfrom mlc_llm.quantization import QUANTIZATION\nfrom mlc_llm.quantization.group_quantization import (\n    GroupQuantize,\n    GroupQuantizeEmbedding,\n    GroupQuantizeLinear,\n)\n\n\ndef quantize_np(config: GroupQuantize, weight: np.ndarray):\n    n, k = weight.shape\n    weight_padded = np.pad(\n        weight,\n        ((0, 0), (0, (config.group_size - k % config.group_size) % config.group_size)),\n    )\n    n, k = weight_padded.shape\n    weight_reshaped = np.reshape(weight_padded, (n, k // config.group_size, config.group_size))\n    max_abs = np.maximum(np.max(np.abs(weight_reshaped), axis=-1), 1e-4)\n    scale = np.divide(max_abs, config.max_int_value)\n    scale_reshaped = np.reshape(scale, (*scale.shape, 1))\n    weight_scaled_reshaped = np.clip(\n        np.add(\n            np.round(np.divide(weight_reshaped, scale_reshaped)),\n            config.max_int_value,\n        ),\n        0,\n        config.max_int_value * 2,\n    ).astype(config.storage_dtype)\n    weight_filtered = np.reshape(weight_scaled_reshaped, (n, k))\n    weight_filtered[..., weight.shape[1] :] = 0\n    weight_scaled = np.reshape(\n        weight_filtered,\n        (n, k // config.num_elem_per_storage, config.num_elem_per_storage),\n    )\n    indice_k = np.indices(weight_scaled.shape, dtype=config.storage_dtype)[-1]\n    quantized_weight = np.sum(\n        np.left_shift(weight_scaled, indice_k * DataType(config.quantize_dtype).bits),\n        axis=-1,\n        dtype=config.storage_dtype,\n    )\n    return quantized_weight, scale\n\n\ndef dequantize_np(\n    config: GroupQuantize,\n    weight: np.ndarray,\n    scale: np.ndarray,\n    out_shape: List[int] = None,\n):\n    assert weight.shape[0] == 
scale.shape[0]\n    bin_mask = (1 << DataType(config.quantize_dtype).bits) - 1\n    max_int = config.max_int_value\n    out_shape = (\n        [weight.shape[0], weight.shape[1] * config.num_elem_per_storage]\n        if out_shape is None\n        else out_shape\n    )\n    weight_repeated = np.repeat(weight, config.num_elem_per_storage, axis=-1)\n    scale_repeated = np.repeat(scale, config.group_size, axis=-1)\n    indice_j = np.indices(weight_repeated.shape)[1]\n    weight_bin = np.bitwise_and(\n        np.right_shift(\n            weight_repeated,\n            (indice_j % config.num_elem_per_storage) * DataType(config.quantize_dtype).bits,\n        ),\n        bin_mask,\n    )\n    assert weight_bin.shape[1] <= scale_repeated.shape[1]\n    return ((weight_bin - max_int) * scale_repeated[..., : weight_bin.shape[1]])[\n        : out_shape[0], : out_shape[1]\n    ]\n\n\n@pytest.mark.parametrize(\n    \"quant_name, shape, dtype, device\",\n    [\n        (\"q3f16_1\", [2, 13], \"float16\", \"cpu\"),\n        (\"q3f16_1\", [16, 120], \"float16\", \"cpu\"),\n        (\"q4f16_1\", [2, 13], \"float16\", \"cpu\"),\n        (\"q4f16_1\", [16, 128], \"float16\", \"cpu\"),\n        (\"q4f32_1\", [2, 13], \"float32\", \"cpu\"),\n        (\"q4f32_1\", [16, 128], \"float32\", \"cpu\"),\n    ],\n)\ndef test_quantize_weight(quant_name: str, shape: List[int], dtype: str, device: str):\n    config = QUANTIZATION[quant_name]\n    assert isinstance(config, GroupQuantize)\n    weight_np = np.random.random(shape).astype(dtype)\n    output = config.quantize_weight(tvm.runtime.tensor(weight_np, device=tvm.device(device)))\n    quantized_weight, scale = output[0].numpy(), output[1].numpy()\n    quantized_weight_ref, scale_ref = quantize_np(config, weight_np)\n    tvm.testing.assert_allclose(scale, scale_ref, rtol=1e-3, atol=1e-3)\n    tvm.testing.assert_allclose(\n        dequantize_np(config, quantized_weight, scale, shape),\n        dequantize_np(config, quantized_weight_ref, 
scale_ref, shape),\n        rtol=1e-2 if quant_name.startswith(\"q3\") else 1e-3,\n        atol=0.4 if quant_name.startswith(\"q3\") else 0.2,\n    )\n\n\n@pytest.mark.parametrize(\n    \"quant_name, shape, dtype\",\n    [\n        (\"q3f16_1\", [2, 13], \"float16\"),\n        (\"q3f16_1\", [16, 120], \"float16\"),\n        (\"q4f16_1\", [2, 13], \"float16\"),\n        (\"q4f16_1\", [16, 128], \"float16\"),\n        (\"q4f32_1\", [2, 13], \"float32\"),\n        (\"q4f32_1\", [16, 128], \"float32\"),\n    ],\n)\ndef test_dequantize_weight(quant_name: str, shape: List[int], dtype: str):\n    class Test(nn.Module):\n        def __init__(self) -> None:\n            super().__init__()\n            self.linear = nn.Linear(shape[1], shape[0], bias=False, dtype=dtype)\n\n        def forward(self, x: nn.Tensor):\n            return self.linear(x)\n\n    config = QUANTIZATION[quant_name]\n    assert isinstance(config, GroupQuantize)\n    num_group = -(shape[1] // -config.group_size)\n    weight_np = np.random.randint(\n        np.iinfo(config.storage_dtype).min,\n        np.iinfo(config.storage_dtype).max,\n        (shape[0], config.num_storage_per_group * num_group),\n    ).astype(config.storage_dtype)\n    scale_np = np.random.random((shape[0], num_group)).astype(config.model_dtype)\n    mod = config.quantize_model(Test(), QuantizeMapping({}, {}), \"\")\n    mod.linear.q_weight.data = weight_np\n    mod.linear.q_scale.data = scale_np\n    model = mod.jit(spec={\"forward\": {\"x\": nn.spec.Tensor((shape[1], shape[1]), dtype)}})\n    out = model[\"forward\"](\n        torch.from_numpy(np.diag(np.ones(shape[1]).astype(dtype)))  # pylint: disable=no-member\n    )\n    ref = dequantize_np(config, weight_np, scale_np, shape).T\n    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)\n\n\n@pytest.mark.parametrize(\n    \"quant_name, shape, dtype\",\n    [\n        (\"q3f16_1\", [16, 128], \"float16\"),\n        (\"q4f16_1\", [16, 128], \"float16\"),\n        (\"q4f32_1\", 
[16, 128], \"float32\"),\n    ],\n)\ndef test_quantize_model(quant_name: str, shape: List[int], dtype: str):\n    class Test(nn.Module):\n        def __init__(self) -> None:\n            super().__init__()\n            self.linear = nn.Linear(shape[0], shape[1], dtype=dtype)\n            self.embedding = nn.Embedding(shape[0], shape[1], dtype=dtype)\n\n        def forward(self, x: nn.Tensor):\n            return self.linear(x)\n\n    config = QUANTIZATION[quant_name]\n    assert isinstance(config, GroupQuantize)\n    quant_map = QuantizeMapping({}, {})\n    mod = config.quantize_model(Test(), quant_map, \"model\")\n    assert quant_map.param_map[\"model.linear.weight\"] == [\n        \"model.linear.q_weight\",\n        \"model.linear.q_scale\",\n    ]\n    assert quant_map.map_func[\"model.linear.weight\"] == config.quantize_weight\n    assert isinstance(mod.linear, GroupQuantizeLinear)\n    assert quant_map.param_map[\"model.embedding.weight\"] == [\n        \"model.embedding.q_weight\",\n        \"model.embedding.q_scale\",\n    ]\n    assert quant_map.map_func[\"model.embedding.weight\"] == config.quantize_weight\n    assert isinstance(mod.embedding, GroupQuantizeEmbedding)\n\n\nif __name__ == \"__main__\":\n    test_quantize_weight(\"q4f16_1\", [16, 128], \"float16\", \"llvm\")\n    test_quantize_model(\"q4f16_1\", [16, 128], \"float16\")\n    test_dequantize_weight(\"q4f16_1\", [16, 128], \"float16\")\n"
  },
  {
    "path": "tests/python/router/test_router.py",
    "content": "import asyncio\n\nfrom mlc_llm.protocol import openai_api_protocol\nfrom mlc_llm.router import Router\n\nmodel_tp1 = \"./dist/Llama-3.2-1B-Instruct-q0f16-MLC/\"\nmodel_lib_tp1 = \"./dist/lib/Llama-3.2-1B-q0f16-cuda.so\"\n# model_lib_tp1 = None\n\nmodel_tp2 = \"./dist/Llama-3.2-1B-Instruct-q0f16-MLC-tp2/\"\nmodel_lib_tp2 = \"./dist/lib/Llama-3.2-1B-q0f16-cuda-tp2.so\"\n# model_lib_tp2 = None\n\n\ndef get_router_1tp1():\n    return (\n        Router(\n            model_tp1,\n            model_lib=model_lib_tp1,\n            hosts=[\"127.0.0.1\"],\n            ports=[8080],\n        ),\n        model_tp1,\n    )\n\n\ndef get_router_2tp1():\n    return (\n        Router(\n            model_tp1,\n            model_lib=model_lib_tp1,\n            hosts=[\"127.0.0.1\", \"127.0.0.1\"],\n            ports=[8080, 8081],\n            device_id_starts=[0, 1],\n            npes=2,\n        ),\n        model_tp1,\n    )\n\n\ndef get_router_1tp2():\n    return (\n        Router(\n            model_tp2,\n            model_lib=model_lib_tp2,\n            hosts=[\"127.0.0.1\"],\n            ports=[8080],\n            npes=2,\n        ),\n        model_tp2,\n    )\n\n\ndef get_router_2tp2():\n    return (\n        Router(\n            model_tp2,\n            model_lib=model_lib_tp2,\n            hosts=[\"127.0.0.1\", \"127.0.0.1\"],\n            ports=[8080, 8081],\n            device_id_starts=[0, 2],\n            npes=4,\n        ),\n        model_tp2,\n    )\n\n\nCONFIG_TO_ROUTER = {\n    \"1tp1\": get_router_1tp1,\n    \"2tp1\": get_router_2tp1,\n    \"1tp2\": get_router_1tp2,\n    \"2tp2\": get_router_2tp2,\n}\n\n\nasync def test_router(schedule: str = \"round_robin\", endpoints_config: str = \"1tp1\"):\n    router, model_id = CONFIG_TO_ROUTER[endpoints_config]()\n\n    request = openai_api_protocol.CompletionRequest(\n        prompt=\"The meaning of life \",\n        model=model_id,\n        stream=True,\n        max_tokens=64,\n        
stream_options=openai_api_protocol.StreamOptions(include_usage=True),\n    )\n    if schedule == \"round_robin\":\n        async for chunk in router._handle_completion_round_robin(request, \"1\"):\n            print(chunk)\n    elif schedule == \"disagg\":\n        async for chunk in router._handle_completion_disagg(request, \"1\"):\n            print(chunk)\n    else:\n        raise ValueError(f\"Unknown scheduling method: {schedule}\")\n    router.terminate()\n\n\nif __name__ == \"__main__\":\n    # asyncio.run(test_router(\"round_robin\", endpoints_config=\"1tp1\"))\n    # asyncio.run(test_router(\"round_robin\", endpoints_config=\"1tp2\"))\n    # asyncio.run(test_router(\"round_robin\", endpoints_config=\"2tp1\"))\n    asyncio.run(test_router(\"round_robin\", endpoints_config=\"2tp2\"))\n"
  },
  {
    "path": "tests/python/serve/evaluate_engine.py",
    "content": "# pylint: disable=line-too-long,missing-docstring\nimport argparse\nimport os\nimport random\nfrom typing import List, Tuple\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine\n\n\ndef _parse_args():\n    args = argparse.ArgumentParser()\n    args.add_argument(\"--model-lib\", type=str)\n    args.add_argument(\"--device\", type=str, default=\"auto\")\n    args.add_argument(\"--batch-size\", type=int, default=80)\n    args.add_argument(\"--max-total-seq-length\", type=int)\n    args.add_argument(\"--seed\", type=int, default=0)\n\n    parsed = args.parse_args()\n    parsed.model = os.path.dirname(parsed.model_lib)\n    assert parsed.batch_size % 16 == 0\n    return parsed\n\n\ndef generate_requests(\n    num_requests: int, input_length: int, output_length: int\n) -> Tuple[List[List[int]], List[GenerationConfig]]:\n    prompt_ids = []\n    for _ in range(num_requests):\n        token_ids = []\n        for _ in range(input_length):\n            token_ids.append(random.randint(0, 30000))\n        prompt_ids.append(token_ids)\n    generation_config_list = [\n        GenerationConfig(temperature=1.0, top_p=1.0, max_tokens=output_length)\n    ] * num_requests\n    return prompt_ids, generation_config_list\n\n\ndef benchmark(args: argparse.Namespace):\n    random.seed(args.seed)\n\n    # Create engine\n    engine = SyncMLCEngine(\n        model=args.model,\n        device=args.device,\n        model_lib=args.model_lib,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_num_sequence=args.batch_size,\n            max_total_sequence_length=args.max_total_seq_length,\n        ),\n    )\n\n    print(args)\n    for num_requests in [1, 2, 4, 8, 16, 32, 64]:\n        if num_requests > args.batch_size:\n            continue\n        for input_length in [64, 128, 256, 512, 1024]:\n            if num_requests * input_length >= 16384:\n                
continue\n            for output_length in [4]:\n                print(f\"nreq={num_requests}\\tin={input_length}\\tout={output_length}\")\n                prompt_ids, generation_config = generate_requests(\n                    num_requests, input_length, output_length\n                )\n                engine.reset()\n                engine.generate(prompt_ids, generation_config)\n                print()\n\n\nif __name__ == \"__main__\":\n    ARGS = _parse_args()\n    benchmark(ARGS)\n"
  },
  {
    "path": "tests/python/serve/server/conftest.py",
    "content": "# pylint: disable=missing-module-docstring,missing-function-docstring\nimport os\nfrom typing import Tuple\n\nimport pytest\n\nfrom mlc_llm.serve import PopenServer\n\n\n@pytest.fixture(scope=\"session\")\ndef served_model() -> Tuple[str, str]:\n    model_lib = os.environ.get(\"MLC_SERVE_MODEL_LIB\")\n    if model_lib is None:\n        raise ValueError(\n            'Environment variable \"MLC_SERVE_MODEL_LIB\" not found. '\n            \"Please set it to model lib compiled by MLC LLM \"\n            \"(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`).\"\n        )\n    model = os.path.dirname(model_lib)\n    return model, model_lib\n\n\n@pytest.fixture(scope=\"session\")\ndef launch_server(served_model):  # pylint: disable=redefined-outer-name\n    \"\"\"A pytest session-level fixture which launches the server in a subprocess.\"\"\"\n    server = PopenServer(\n        model=served_model[0],\n        model_lib=served_model[1],\n        enable_tracing=True,\n        enable_debug=True,\n        port=8000,\n    )\n\n    with server:\n        yield\n"
  },
  {
    "path": "tests/python/serve/server/test_embedding_server.py",
    "content": "\"\"\"Embedding server endpoint tests in MLC LLM.\n\nTests the /v1/embeddings endpoint via HTTP using the OpenAI client,\nfollowing the same patterns as test_server.py.\n\nReuses MLC LLM test infrastructure:\n  - Pytest markers (endpoint)\n  - expect_error() response validation pattern from test_server.py\n  - OpenAI client usage pattern from test_server.py\n  - Session-scoped server fixture pattern from conftest.py\n\nRun (launches its own embedding-only server):\n  MLC_SERVE_EMBEDDING_MODEL_LIB=\"path/to/model.dylib\" \\\n    pytest -m endpoint tests/python/serve/server/test_embedding_server.py -v\n\nEnvironment variables:\n  MLC_SERVE_EMBEDDING_MODEL_LIB  Path to compiled embedding model library (required)\n  MLC_SERVE_EMBEDDING_MODEL      Path to embedding model weight directory\n                                  (optional, defaults to dirname of model lib)\n\"\"\"\n\n# pylint: disable=redefined-outer-name\n\nimport json\nimport os\nimport signal\nimport subprocess\nimport sys\nimport time\nfrom pathlib import Path\nfrom typing import Dict, Optional\n\nimport numpy as np\nimport pytest\nimport requests\nfrom openai import OpenAI\n\n# Reuse MLC LLM marker system\npytestmark = [pytest.mark.endpoint]\n\n# ---------------------------------------------------------------------------\n# Config\n# ---------------------------------------------------------------------------\n\nEMBEDDING_MODEL_LIB = os.environ.get(\"MLC_SERVE_EMBEDDING_MODEL_LIB\")\nEMBEDDING_MODEL_DIR = os.environ.get(\n    \"MLC_SERVE_EMBEDDING_MODEL\",\n    os.path.dirname(EMBEDDING_MODEL_LIB) if EMBEDDING_MODEL_LIB else None,\n)\nEMBEDDING_SERVER_HOST = \"127.0.0.1\"\nEMBEDDING_SERVER_PORT = 8321\nEMBEDDING_BASE_URL = f\"http://{EMBEDDING_SERVER_HOST}:{EMBEDDING_SERVER_PORT}/v1\"\nEMBEDDING_MODEL_NAME = \"embedding\"\n\n\ndef _skip_if_no_model():\n    if EMBEDDING_MODEL_LIB is None:\n        pytest.skip(\n            'Environment variable \"MLC_SERVE_EMBEDDING_MODEL_LIB\" not found. 
'\n            \"Set it to a compiled embedding model library.\"\n        )\n    if not os.path.isfile(EMBEDDING_MODEL_LIB):\n        pytest.skip(f\"Embedding model library not found at: {EMBEDDING_MODEL_LIB}\")\n    if EMBEDDING_MODEL_DIR is None or not os.path.isdir(EMBEDDING_MODEL_DIR):\n        pytest.skip(f\"Embedding model directory not found at: {EMBEDDING_MODEL_DIR}\")\n\n\n# ---------------------------------------------------------------------------\n# Response validation helpers — adapted from test_server.py patterns\n# ---------------------------------------------------------------------------\n\n\ndef check_embedding_response(\n    response: Dict,\n    *,\n    model: str,\n    num_embeddings: int,\n    expected_dim: Optional[int] = None,\n    check_unit_norm: bool = True,\n):\n    \"\"\"Validate an OpenAI-compatible embedding response.\n\n    Adapted from check_openai_nonstream_response() in test_server.py,\n    specialized for embedding responses.\n    \"\"\"\n    assert response[\"object\"] == \"list\"\n    assert response[\"model\"] == model\n\n    data = response[\"data\"]\n    assert isinstance(data, list)\n    assert len(data) == num_embeddings\n\n    for item in data:\n        assert item[\"object\"] == \"embedding\"\n        assert isinstance(item[\"index\"], int)\n        emb = item[\"embedding\"]\n        assert isinstance(emb, list)\n        assert len(emb) > 0\n\n        if expected_dim is not None:\n            assert len(emb) == expected_dim, f\"Expected dim={expected_dim}, got {len(emb)}\"\n\n        if check_unit_norm:\n            norm = float(np.linalg.norm(emb))\n            assert abs(norm - 1.0) < 1e-3, f\"Expected unit norm, got {norm}\"\n\n    # Usage validation — same pattern as test_server.py\n    usage = response[\"usage\"]\n    assert isinstance(usage, dict)\n    assert usage[\"prompt_tokens\"] > 0\n    assert usage[\"total_tokens\"] == usage[\"prompt_tokens\"]\n\n\ndef expect_error(response_str: str, msg_prefix: Optional[str] 
= None):\n    \"\"\"Validate error response — reused directly from test_server.py.\"\"\"\n    response = json.loads(response_str)\n    assert response[\"object\"] == \"error\"\n    assert isinstance(response[\"message\"], str)\n    if msg_prefix is not None:\n        assert response[\"message\"].startswith(msg_prefix)\n\n\n# ---------------------------------------------------------------------------\n# Server fixture — follows PopenServer/launch_server pattern from conftest.py\n# ---------------------------------------------------------------------------\n\n\n@pytest.fixture(scope=\"module\")\ndef launch_embedding_server():\n    \"\"\"Launch an embedding-only server as a subprocess.\n\n    Follows the same lifecycle pattern as the launch_server fixture\n    in serve/server/conftest.py, but uses a lightweight embedding-only\n    server since PopenServer doesn't support --embedding-model yet.\n    \"\"\"\n    _skip_if_no_model()\n\n    mlc_llm_path = str(Path(__file__).resolve().parents[4] / \"python\")\n    server_code = f\"\"\"\nimport sys\nsys.path.insert(0, \"{mlc_llm_path}\")\n\nimport fastapi\nimport uvicorn\nfrom mlc_llm.serve.embedding_engine import AsyncEmbeddingEngine\nfrom mlc_llm.serve.server import ServerContext\nfrom mlc_llm.serve.entrypoints import openai_entrypoints\n\napp = fastapi.FastAPI()\napp.include_router(openai_entrypoints.app)\n\nengine = AsyncEmbeddingEngine(\n    model=\"{EMBEDDING_MODEL_DIR}\",\n    model_lib=\"{EMBEDDING_MODEL_LIB}\",\n    device=\"auto\",\n)\nctx = ServerContext()\nServerContext.server_context = ctx\nctx.add_embedding_engine(\"{EMBEDDING_MODEL_NAME}\", engine)\n\nuvicorn.run(app, host=\"{EMBEDDING_SERVER_HOST}\", port={EMBEDDING_SERVER_PORT}, log_level=\"info\")\n\"\"\"\n    with subprocess.Popen(\n        [sys.executable, \"-c\", server_code],\n        stdout=subprocess.PIPE,\n        stderr=subprocess.PIPE,\n    ) as proc:\n        # Wait for server readiness — same polling pattern as PopenServer.start()\n        
timeout = 120\n        attempts = 0.0\n        ready = False\n        while attempts < timeout:\n            try:\n                response = requests.get(f\"{EMBEDDING_BASE_URL}/models\", timeout=2)\n                if response.status_code == 200:\n                    ready = True\n                    break\n            except requests.RequestException:\n                pass\n            attempts += 0.5\n            time.sleep(0.5)\n\n        if not ready:\n            # Kill the process first; reading stderr from a still-running process would block.\n            proc.kill()\n            stderr = proc.stderr.read().decode() if proc.stderr else \"\"\n            raise RuntimeError(f\"Embedding server failed to start in {timeout}s.\\nStderr: {stderr}\")\n\n        yield proc\n\n        # Cleanup — same pattern as PopenServer.terminate()\n        proc.send_signal(signal.SIGINT)\n        try:\n            proc.wait(timeout=10)\n        except subprocess.TimeoutExpired:\n            proc.kill()\n\n\n@pytest.fixture(scope=\"module\")\ndef client(launch_embedding_server):\n    \"\"\"OpenAI client connected to the embedding server.\"\"\"\n    assert launch_embedding_server is not None\n    return OpenAI(base_url=EMBEDDING_BASE_URL, api_key=\"none\")\n\n\n# ===================================================================\n# /v1/models\n# ===================================================================\n\n\n@pytest.mark.usefixtures(\"client\")\ndef test_models_endpoint():\n    \"\"\"The /v1/models endpoint lists the embedding model.\"\"\"\n    resp = requests.get(f\"{EMBEDDING_BASE_URL}/models\", timeout=5)\n    assert resp.status_code == 200\n    data = resp.json()\n    assert isinstance(data[\"data\"], list)\n\n\n# ===================================================================\n# Single input\n# ===================================================================\n\n\ndef test_single_string_input(client):\n    \"\"\"Single string input returns one embedding.\"\"\"\n    resp = client.embeddings.create(input=\"What is machine learning?\", 
model=EMBEDDING_MODEL_NAME)\n    raw = resp.model_dump()\n    check_embedding_response(raw, model=EMBEDDING_MODEL_NAME, num_embeddings=1)\n\n\n# ===================================================================\n# Batch input\n# ===================================================================\n\nBATCH_INPUTS = [\n    \"What is machine learning?\",\n    \"How to brew coffee?\",\n    \"ML is a subset of AI.\",\n]\n\n\ndef test_batch_string_input(client):\n    \"\"\"List of strings returns one embedding per input.\"\"\"\n    resp = client.embeddings.create(input=BATCH_INPUTS, model=EMBEDDING_MODEL_NAME)\n    raw = resp.model_dump()\n    check_embedding_response(raw, model=EMBEDDING_MODEL_NAME, num_embeddings=len(BATCH_INPUTS))\n\n\ndef test_batch_index_ordering(client):\n    \"\"\"Embedding indices are sequential.\"\"\"\n    resp = client.embeddings.create(input=BATCH_INPUTS, model=EMBEDDING_MODEL_NAME)\n    indices = [d.index for d in resp.data]\n    assert indices == list(range(len(BATCH_INPUTS)))\n\n\n# ===================================================================\n# Cosine similarity — semantic quality via endpoint\n# ===================================================================\n\n\ndef test_cosine_similarity_via_endpoint(client):\n    \"\"\"Related texts have higher similarity than unrelated (end-to-end).\"\"\"\n    resp = client.embeddings.create(\n        input=[\n            \"What is machine learning?\",\n            \"Explain deep learning\",\n            \"Order a pizza\",\n        ],\n        model=EMBEDDING_MODEL_NAME,\n    )\n    e0, e1, e2 = [np.array(d.embedding) for d in resp.data]\n    sim_related = float(np.dot(e0, e1))\n    sim_unrelated = float(np.dot(e0, e2))\n    assert (\n        sim_related > sim_unrelated\n    ), f\"Related ({sim_related:.4f}) should exceed unrelated ({sim_unrelated:.4f})\"\n\n\n# ===================================================================\n# Dimension truncation (Matryoshka)\n# 
===================================================================\n\n\ndef test_dimension_truncation(client):\n    \"\"\"dimensions parameter truncates and re-normalizes output.\"\"\"\n    target_dim = 256\n    resp = client.embeddings.create(\n        input=\"Hello world\", model=EMBEDDING_MODEL_NAME, dimensions=target_dim\n    )\n    raw = resp.model_dump()\n    check_embedding_response(\n        raw,\n        model=EMBEDDING_MODEL_NAME,\n        num_embeddings=1,\n        expected_dim=target_dim,\n    )\n\n\n# ===================================================================\n# Encoding format\n# ===================================================================\n\n\n@pytest.mark.usefixtures(\"launch_embedding_server\")\ndef test_base64_encoding():\n    \"\"\"base64 encoding format returns base64-encoded embeddings.\"\"\"\n    resp = requests.post(\n        f\"{EMBEDDING_BASE_URL}/embeddings\",\n        json={\n            \"input\": \"Hello world\",\n            \"model\": EMBEDDING_MODEL_NAME,\n            \"encoding_format\": \"base64\",\n        },\n        timeout=5,\n    )\n    assert resp.status_code == 200\n    data = resp.json()\n    assert data[\"data\"][0][\"object\"] == \"embedding\"\n    # base64 string should be a non-empty string (not a list)\n    emb = data[\"data\"][0][\"embedding\"]\n    assert isinstance(emb, str) and len(emb) > 0\n\n\n# ===================================================================\n# Error handling — reuses expect_error() pattern from test_server.py\n# ===================================================================\n\n\n@pytest.mark.usefixtures(\"launch_embedding_server\")\ndef test_any_model_name_works_with_single_engine():\n    \"\"\"When only one embedding engine is served, any model name works.\n\n    This mirrors ServerContext.get_engine() behavior: a single served\n    model is returned regardless of the requested model name.\n    \"\"\"\n    resp = requests.post(\n        
f\"{EMBEDDING_BASE_URL}/embeddings\",\n        json={\"input\": \"test\", \"model\": \"any-name-works\"},\n        timeout=5,\n    )\n    assert resp.status_code == 200\n    data = resp.json()\n    assert len(data[\"data\"]) == 1\n\n\n# ===================================================================\n# Standalone runner (same pattern as test_server.py __main__)\n# ===================================================================\n\nif __name__ == \"__main__\":\n    _skip_if_no_model()\n\n    print(f\"Using model: {EMBEDDING_MODEL_DIR}\")\n    print(f\"Using model lib: {EMBEDDING_MODEL_LIB}\")\n    print(f\"Server URL: {EMBEDDING_BASE_URL}\")\n    print(\n        \"\\nMake sure the embedding server is running, or set env vars \"\n        \"and use pytest to auto-launch.\"\n    )\n\n    # Allow running against an already-running server\n    c = OpenAI(base_url=EMBEDDING_BASE_URL, api_key=\"none\")\n    test_models_endpoint()\n    test_single_string_input(c)\n    test_batch_string_input(c)\n    test_batch_index_ordering(c)\n    test_cosine_similarity_via_endpoint(c)\n    test_dimension_truncation(c)\n    test_base64_encoding()\n    test_any_model_name_works_with_single_engine()\n    print(\"\\nAll embedding server tests passed!\")\n"
  },
  {
    "path": "tests/python/serve/server/test_server.py",
    "content": "\"\"\"Server tests in MLC LLM.\nBefore running any test, we use pytest fixtures to launch a\ntest-session-wide server in a subprocess, and then execute the tests.\n\nThe recommended way to run the tests is to use the following command:\n  MLC_SERVE_MODEL_LIB=\"YOUR_MODEL_LIB\" pytest -vv tests/python/serve/server/test_server.py\n\nHere \"YOUR_MODEL_LIB\" is a compiled model library like\n`dist/Llama-2-7b-chat-hf-q4f16_1/Llama-2-7b-chat-hf-q4f16_1-cuda.so`,\nas long as the model is built with batching and embedding separation enabled.\n\nTo run this Python file directly (i.e., without pytest), you need to\nlaunch the server before running this file. This can be done in\ntwo steps:\n- start a new shell session and run\n  python -m mlc_llm.serve.server --model \"YOUR_MODEL_LIB\"\n- start another shell session and run this file\n  MLC_SERVE_MODEL_LIB=\"YOUR_MODEL_LIB\" python tests/python/serve/server/test_server.py\n\"\"\"\n\n# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches\nimport json\nimport os\nfrom http import HTTPStatus\nfrom typing import Dict, List, Optional, Tuple\n\nimport pytest\nimport regex\nimport requests\nfrom openai import OpenAI\nfrom pydantic import BaseModel\n\nfrom mlc_llm.protocol.openai_api_protocol import (\n    CHAT_COMPLETION_MAX_TOP_LOGPROBS,\n    COMPLETION_MAX_TOP_LOGPROBS,\n)\n\nOPENAI_BASE_URL = \"http://127.0.0.1:8000/v1\"\nOPENAI_V1_MODELS_URL = \"http://127.0.0.1:8000/v1/models\"\nOPENAI_V1_COMPLETION_URL = \"http://127.0.0.1:8000/v1/completions\"\nOPENAI_V1_CHAT_COMPLETION_URL = \"http://127.0.0.1:8000/v1/chat/completions\"\nDEBUG_DUMP_EVENT_TRACE_URL = \"http://127.0.0.1:8000/debug/dump_event_trace\"\nMETRICS_URL = \"http://127.0.0.1:8000/metrics\"\n\n\nJSON_TOKEN_PATTERN = (\n    r\"((-?(?:0|[1-9]\\d*))(\\.\\d+)?([eE][-+]?\\d+)?)|null|true|false|\"\n    r'(\"((\\\\[\"\\\\\\/bfnrt])|(\\\\u[0-9a-fA-F]{4})|[^\"\\\\\\x00-\\x1f])*\")'\n)\nJSON_TOKEN_RE = 
regex.compile(JSON_TOKEN_PATTERN)\n\n\ndef is_json(s: str) -> bool:\n    try:\n        json.loads(s)\n        return True\n    except json.JSONDecodeError:\n        return False\n\n\ndef is_json_prefix(s: str) -> bool:\n    try:\n        json.loads(s)\n        return True\n    except json.JSONDecodeError as e:\n        # If the JSON decoder reaches the end of s, it is a prefix of a JSON string.\n        if e.pos == len(s):\n            return True\n        # Since json.loads is token-based instead of char-based, a partial token may remain\n        # after the error position.\n        # If the remaining part is a prefix of a valid JSON token, the output is still a valid prefix.\n        regex_match = JSON_TOKEN_RE.fullmatch(s[e.pos :], partial=True)\n        return regex_match is not None\n\n\ndef check_openai_nonstream_response(\n    response: Dict,\n    *,\n    is_chat_completion: bool,\n    model: str,\n    object_str: str,\n    num_choices: int,\n    finish_reasons: List[str],\n    completion_tokens: Optional[int] = None,\n    echo_prompt: Optional[str] = None,\n    suffix: Optional[str] = None,\n    stop: Optional[List[str]] = None,\n    require_substr: Optional[List[str]] = None,\n    check_json_output: bool = False,\n):\n    assert response[\"model\"] == model\n    assert response[\"object\"] == object_str\n\n    choices = response[\"choices\"]\n    assert isinstance(choices, list)\n    assert len(choices) <= num_choices\n    texts: List[str] = [\"\" for _ in range(num_choices)]\n    for choice in choices:\n        idx = choice[\"index\"]\n        assert choice[\"finish_reason\"] in finish_reasons\n\n        if not is_chat_completion:\n            assert isinstance(choice[\"text\"], str)\n            texts[idx] = choice[\"text\"]\n            if echo_prompt is not None:\n                assert texts[idx].startswith(echo_prompt)\n            if suffix is not None:\n                assert texts[idx].endswith(suffix)\n        else:\n            message = choice[\"message\"]\n            assert message[\"role\"] 
== \"assistant\"\n            assert isinstance(message[\"content\"], str)\n            texts[idx] = message[\"content\"]\n\n        if stop is not None:\n            for stop_str in stop:\n                assert stop_str not in texts[idx]\n        if require_substr is not None:\n            for substr in require_substr:\n                assert substr in texts[idx]\n        if check_json_output:\n            # the output should be json or a prefix of a json string\n            # if the output is a prefix of a json string, the output must exceed the max output\n            # length\n            output_is_json = is_json(texts[idx])\n            output_is_json_prefix = is_json_prefix(texts[idx])\n            assert output_is_json or output_is_json_prefix\n            if not output_is_json and output_is_json_prefix:\n                assert choice[\"finish_reason\"] == \"length\"\n\n    usage = response[\"usage\"]\n    if usage is not None:\n        assert isinstance(usage, dict)\n        assert usage[\"total_tokens\"] == usage[\"prompt_tokens\"] + usage[\"completion_tokens\"]\n        assert usage[\"prompt_tokens\"] > 0\n        if completion_tokens is not None:\n            assert usage[\"completion_tokens\"] == completion_tokens\n\n\ndef check_openai_stream_response(\n    responses: List[Dict],\n    *,\n    is_chat_completion: bool,\n    model: str,\n    object_str: str,\n    num_choices: int,\n    finish_reasons: List[str],\n    completion_tokens: Optional[int] = None,\n    echo_prompt: Optional[str] = None,\n    suffix: Optional[str] = None,\n    stop: Optional[List[str]] = None,\n    require_substr: Optional[List[str]] = None,\n    check_json_output: bool = False,\n):\n    assert len(responses) > 0\n\n    finished = [False for _ in range(num_choices)]\n    outputs = [\"\" for _ in range(num_choices)]\n    finish_reason_list = [\"\" for _ in range(num_choices)]\n    for response in responses:\n        assert response[\"model\"] == model\n        assert 
response[\"object\"] == object_str\n\n        choices = response[\"choices\"]\n        assert isinstance(choices, list)\n        assert len(choices) <= num_choices\n        for choice in choices:\n            idx = choice[\"index\"]\n\n            if not is_chat_completion:\n                assert isinstance(choice[\"text\"], str)\n                outputs[idx] += choice[\"text\"]\n            else:\n                delta = choice[\"delta\"]\n                assert delta[\"role\"] == \"assistant\"\n                assert isinstance(delta[\"content\"], str)\n                outputs[idx] += delta[\"content\"]\n\n            if finished[idx]:\n                assert choice[\"finish_reason\"] in finish_reasons\n                finish_reason_list[idx] = choice[\"finish_reason\"]\n            elif choice[\"finish_reason\"] is not None:\n                assert choice[\"finish_reason\"] in finish_reasons\n                finish_reason_list[idx] = choice[\"finish_reason\"]\n                finished[idx] = True\n\n        if not is_chat_completion:\n            usage = response[\"usage\"]\n            if usage is not None:\n                assert isinstance(usage, dict)\n                assert usage[\"total_tokens\"] == usage[\"prompt_tokens\"] + usage[\"completion_tokens\"]\n                assert usage[\"prompt_tokens\"] >= 0\n                if completion_tokens is not None:\n                    assert usage[\"completion_tokens\"] <= completion_tokens\n\n    if not is_chat_completion:\n        if completion_tokens is not None and responses[-1][\"usage\"] is not None:\n            assert responses[-1][\"usage\"][\"completion_tokens\"] == completion_tokens\n\n    for i, (output, finish_reason) in enumerate(zip(outputs, finish_reason_list)):\n        if echo_prompt is not None:\n            assert output.startswith(echo_prompt)\n        if suffix is not None:\n            assert output.endswith(suffix)\n        if stop is not None:\n            for stop_str in stop:\n         
       assert stop_str not in output\n        if require_substr is not None:\n            for substr in require_substr:\n                assert substr in output\n        if check_json_output:\n            # the output should be json or a prefix of a json string\n            # if the output is a prefix of a json string, the output must exceed the max output\n            # length\n            output_is_json = is_json(output)\n            output_is_json_prefix = is_json_prefix(output)\n            assert output_is_json or output_is_json_prefix\n            if not output_is_json and output_is_json_prefix:\n                assert finish_reason == \"length\"\n\n\ndef expect_error(response_str: str, msg_prefix: Optional[str] = None):\n    response = json.loads(response_str)\n    assert response[\"object\"] == \"error\"\n    assert isinstance(response[\"message\"], str)\n    if msg_prefix is not None:\n        assert response[\"message\"].startswith(msg_prefix)\n\n\ndef test_openai_v1_models(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    response = requests.get(OPENAI_V1_MODELS_URL, timeout=180).json()\n    assert response[\"object\"] == \"list\"\n    models = response[\"data\"]\n    assert isinstance(models, list)\n    assert len(models) == 1\n\n    model_card = models[0]\n    assert isinstance(model_card, dict)\n    assert model_card[\"id\"] == served_model[0], f\"{model_card['id']} {served_model[0]}\"\n    assert model_card[\"object\"] == \"model\"\n    assert model_card[\"owned_by\"] == \"MLC-LLM\"\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = \"What is the 
meaning of life?\"\n    max_tokens = 256\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": stream,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_openai_package(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    client = OpenAI(base_url=OPENAI_BASE_URL, api_key=\"None\")\n    prompt = \"What is the meaning of life?\"\n    max_tokens = 256\n    response = client.completions.create(\n        model=served_model[0],\n        prompt=prompt,\n        max_tokens=max_tokens,\n        stream=stream,\n    )\n    if not stream:\n        check_openai_nonstream_response(\n            response.model_dump(),\n            is_chat_completion=False,\n            
model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            completion_tokens=max_tokens,\n        )\n    else:\n        responses = []\n        for chunk in response:  # pylint: disable=not-an-iterable\n            responses.append(chunk.model_dump())\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            completion_tokens=max_tokens,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_echo(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = \"What is the meaning of life?\"\n    max_tokens = 256\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"echo\": True,\n        \"stream\": stream,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n            echo_prompt=prompt,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        
check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n            echo_prompt=prompt,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_suffix(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = \"What is the meaning of life?\"\n    suffix = \"Hello, world!\"\n    max_tokens = 256\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"suffix\": suffix,\n        \"stream\": stream,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n            suffix=suffix,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n            suffix=suffix,\n    
    )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_stop_str(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    # Choose \"in\" as the stop string since it is very unlikely that\n    # \"in\" does not appear in the generated output.\n    prompt = \"What is the meaning of life?\"\n    stop = [\"in\"]\n    max_tokens = 256\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stop\": stop,\n        \"stream\": stream,\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"stop\", \"length\"],\n            stop=stop,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"stop\", \"length\"],\n            stop=stop,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_temperature(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = 
\"What's the meaning of life?\"\n    max_tokens = 128\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": stream,\n        \"temperature\": 0.0,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_json(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = \"Respond with a JSON object:\"\n    max_tokens = 128\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": stream,\n        \"response_format\": {\"type\": \"json_object\"},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            
model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_json_schema(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = (\n        \"Generate a json containing three fields: an integer field named size, a \"\n        \"boolean field named is_accepted, and a float field named num:\"\n    )\n    max_tokens = 128\n\n    class Schema(BaseModel):\n        size: int\n        is_accepted: bool\n        num: float\n\n    schema_str = json.dumps(Schema.model_json_schema())\n\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": stream,\n        \"response_format\": {\"type\": \"json_object\", \"schema\": schema_str},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=60)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n     
       finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_logit_bias(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    # NOTE: This test only tests that the system does not break on logit bias.\n    #       The test does not promise the correctness of logit bias handling.\n\n    prompt = \"What's the meaning of life?\"\n    max_tokens = 128\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": stream,\n        \"logit_bias\": {338: -100},  # 338 is \" is\" in Llama tokenizer.\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: 
[DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_presence_frequency_penalty(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = \"What's the meaning of life?\"\n    max_tokens = 128\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": stream,\n        \"frequency_penalty\": 2.0,\n        \"presence_penalty\": 2.0,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n\n\ndef test_openai_v1_completions_seed(\n    served_model: 
Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = \"What's the meaning of life?\"\n    max_tokens = 128\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": max_tokens,\n        \"stream\": False,\n        \"seed\": 233,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response1 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    response2 = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    for response in [response1, response2]:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=False,\n            model=served_model[0],\n            object_str=\"text_completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n        )\n\n    text1 = response1.json()[\"choices\"][0][\"text\"]\n    text2 = response2.json()[\"choices\"][0][\"text\"]\n    assert text1 == text2\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_prompt_overlong(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    num_tokens = 1000000\n    prompt = [128] * num_tokens\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": prompt,\n        \"max_tokens\": 256,\n        \"stream\": stream,\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    error_msg_prefix = (\n        f\"Request prompt has {num_tokens} tokens in total, larger than the model input length limit\"\n    )\n    if not stream:\n        expect_error(response.json(), msg_prefix=error_msg_prefix)\n    else:\n        num_chunks = 0\n        for chunk in 
response.iter_lines(chunk_size=512):\n            if not chunk:\n                continue\n            num_chunks += 1\n            expect_error(json.loads(chunk.decode(\"utf-8\")), msg_prefix=error_msg_prefix)\n        assert num_chunks == 1\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_completions_invalid_logprobs(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": \"What is the meaning of life?\",\n        \"max_tokens\": 256,\n        \"stream\": stream,\n        \"logprobs\": COMPLETION_MAX_TOP_LOGPROBS + 1,\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY\n    assert response.json()[\"detail\"][0][\"msg\"].endswith(\n        f'\"top_logprobs\" must be in range [0, {COMPLETION_MAX_TOP_LOGPROBS}]'\n    )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_chat_completions_invalid_logprobs(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": [{\"role\": \"user\", \"content\": \"Hello! 
Our project is MLC LLM.\"}],\n        \"max_tokens\": 256,\n        \"stream\": stream,\n        \"logprobs\": False,\n        \"top_logprobs\": CHAT_COMPLETION_MAX_TOP_LOGPROBS - 1,\n    }\n\n    # Chat payloads (\"messages\") must go to the chat completions endpoint.\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY\n    assert response.json()[\"detail\"][0][\"msg\"].endswith(\n        '\"logprobs\" must be True to support \"top_logprobs\"'\n    )\n\n    payload[\"logprobs\"] = True\n    payload[\"top_logprobs\"] = CHAT_COMPLETION_MAX_TOP_LOGPROBS + 1\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY\n    assert response.json()[\"detail\"][0][\"msg\"].endswith(\n        f'\"top_logprobs\" must be in range [0, {CHAT_COMPLETION_MAX_TOP_LOGPROBS}]'\n    )\n\n\ndef test_openai_v1_completions_unsupported_args(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    # Right now \"best_of\" is unsupported.\n    best_of = 2\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": \"What is the meaning of life?\",\n        \"max_tokens\": 256,\n        \"best_of\": best_of,\n    }\n\n    response = requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=180)\n    error_msg_prefix = 'Request fields \"best_of\" are not supported right now.'\n    expect_error(response.json(), msg_prefix=error_msg_prefix)\n\n\ndef test_openai_v1_completions_request_cancellation(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    # Use a large max_tokens and small timeout to force 
timeouts.\n    payload = {\n        \"model\": served_model[0],\n        \"prompt\": \"What is the meaning of life?\",\n        \"max_tokens\": 2048,\n        \"stream\": False,\n    }\n    with pytest.raises(requests.exceptions.Timeout):\n        requests.post(OPENAI_V1_COMPLETION_URL, json=payload, timeout=1)\n\n    # The server should still be alive after a request is cancelled.\n    # We query `v1/models` to check server liveness.\n    response = requests.get(OPENAI_V1_MODELS_URL, timeout=180).json()\n\n    assert response[\"object\"] == \"list\"\n    models = response[\"data\"]\n    assert isinstance(models, list)\n    assert len(models) == 1\n\n    model_card = models[0]\n    assert isinstance(model_card, dict)\n    assert model_card[\"id\"] == served_model[0]\n    assert model_card[\"object\"] == \"model\"\n    assert model_card[\"owned_by\"] == \"MLC-LLM\"\n\n\nCHAT_COMPLETION_MESSAGES = [\n    # messages #0\n    [{\"role\": \"user\", \"content\": \"Hello! Our project is MLC LLM.\"}],\n    # messages #1\n    [\n        {\"role\": \"user\", \"content\": \"Hello! Our project is MLC LLM.\"},\n        {\n            \"role\": \"assistant\",\n            \"content\": \"Hello! It's great to hear about your project, MLC LLM.\",\n        },\n        {\"role\": \"user\", \"content\": \"What is the name of our project?\"},\n    ],\n    # messages #2\n    [\n        {\n            \"role\": \"system\",\n            \"content\": \"You are a helpful, respectful and honest assistant. \"\n            \"You always end your response with an emoji.\",\n        },\n        {\"role\": \"user\", \"content\": \"Hello! 
Our project is MLC LLM.\"},\n    ],\n]\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\n@pytest.mark.parametrize(\"messages\", CHAT_COMPLETION_MESSAGES)\ndef test_openai_v1_chat_completions(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n    messages: List[Dict[str, str]],\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"stop\"],\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"stop\"],\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\n@pytest.mark.parametrize(\"messages\", CHAT_COMPLETION_MESSAGES)\ndef test_openai_v1_chat_completions_n(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n    messages: List[Dict[str, str]],\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    n = 3\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": 
stream,\n        \"n\": n,\n        \"max_tokens\": 300,\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=n,\n            finish_reasons=[\"stop\", \"length\"],\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=n,\n            finish_reasons=[\"stop\", \"length\"],\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\n@pytest.mark.parametrize(\"messages\", CHAT_COMPLETION_MESSAGES)\ndef test_openai_v1_chat_completions_openai_package(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n    messages: List[Dict[str, str]],\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    client = OpenAI(base_url=OPENAI_BASE_URL, api_key=\"None\")\n    response = client.chat.completions.create(\n        model=served_model[0],\n        messages=messages,\n        stream=stream,\n        logprobs=True,\n        top_logprobs=2,\n    )\n    if not stream:\n        check_openai_nonstream_response(\n            response.model_dump(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"stop\"],\n        )\n    else:\n        responses = []\n        
for chunk in response:  # pylint: disable=not-an-iterable\n            responses.append(chunk.model_dump())\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"stop\"],\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_chat_completions_max_tokens(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    messages = [{\"role\": \"user\", \"content\": \"Write a novel with at least 500 words.\"}]\n    max_tokens = 16\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n        \"max_tokens\": max_tokens,\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", 
[False, True])\ndef test_openai_v1_chat_completions_json(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    messages = [{\"role\": \"user\", \"content\": \"Respond with a JSON object:\"}]\n    max_tokens = 128\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n        \"max_tokens\": max_tokens,\n        \"response_format\": {\"type\": \"json_object\"},\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_chat_completions_json_schema(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    prompt = (\n        \"Generate a json containing three fields: an integer field named size, 
a \"\n        \"boolean field named is_accepted, and a float field named num:\"\n    )\n    messages = [{\"role\": \"user\", \"content\": prompt}]\n    max_tokens = 128\n\n    class Schema(BaseModel):\n        size: int\n        is_accepted: bool\n        num: float\n\n    schema_str = json.dumps(Schema.model_json_schema())\n\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n        \"max_tokens\": max_tokens,\n        \"response_format\": {\"type\": \"json_object\", \"schema\": schema_str},\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"length\", \"stop\"],\n            check_json_output=True,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_chat_completions_ignore_eos(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    messages = [{\"role\": \"user\", \"content\": \"Write a sentence with less than 20 words.\"}]\n    max_tokens = 
128\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n        \"max_tokens\": max_tokens,\n        \"debug_config\": {\"ignore_eos\": True},\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"length\"],\n            completion_tokens=max_tokens,\n        )\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\ndef test_openai_v1_chat_completions_system_prompt_wrong_pos(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    messages = [\n        {\"role\": \"user\", \"content\": \"Hello! Our project is MLC LLM.\"},\n        {\n            \"role\": \"system\",\n            \"content\": \"You are a helpful, respectful and honest assistant. 
\"\n            \"You always end your response with an emoji.\",\n        },\n    ]\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    error_msg = \"System prompt at position 1 in the message list is invalid.\"\n    if not stream:\n        expect_error(response.json(), msg_prefix=error_msg)\n    else:\n        num_chunks = 0\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk:\n                continue\n            num_chunks += 1\n            expect_error(json.loads(chunk.decode(\"utf-8\")), msg_prefix=error_msg)\n        assert num_chunks == 1\n\n\ndef test_debug_dump_event_trace(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n    # We only check that the request does not fail.\n    payload = {\"model\": served_model[0]}\n    response = requests.post(DEBUG_DUMP_EVENT_TRACE_URL, json=payload, timeout=180)\n    assert response.status_code == HTTPStatus.OK\n\n\ndef test_metrics(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n    # We only check that the prefill-time metric is present in the output.\n    metrics_text = requests.get(METRICS_URL, timeout=180).text\n    assert \"engine_prefill_time_sum\" in metrics_text\n\n\nif __name__ == \"__main__\":\n    model_lib = os.environ.get(\"MLC_SERVE_MODEL_LIB\")\n    if model_lib is None:\n        raise ValueError(\n            'Environment variable \"MLC_SERVE_MODEL_LIB\" not found. 
'\n            \"Please set it to model lib compiled by MLC LLM \"\n            \"(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`).\"\n        )\n    MODEL = (os.path.dirname(model_lib), model_lib)\n\n    test_openai_v1_models(MODEL, None)\n\n    test_openai_v1_completions(MODEL, None, stream=False)\n    test_openai_v1_completions(MODEL, None, stream=True)\n    test_openai_v1_completions_openai_package(MODEL, None, stream=False)\n    test_openai_v1_completions_openai_package(MODEL, None, stream=True)\n    test_openai_v1_completions_echo(MODEL, None, stream=False)\n    test_openai_v1_completions_echo(MODEL, None, stream=True)\n    test_openai_v1_completions_suffix(MODEL, None, stream=False)\n    test_openai_v1_completions_suffix(MODEL, None, stream=True)\n    test_openai_v1_completions_stop_str(MODEL, None, stream=False)\n    test_openai_v1_completions_stop_str(MODEL, None, stream=True)\n    test_openai_v1_completions_temperature(MODEL, None, stream=False)\n    test_openai_v1_completions_temperature(MODEL, None, stream=True)\n    test_openai_v1_completions_logit_bias(MODEL, None, stream=False)\n    test_openai_v1_completions_logit_bias(MODEL, None, stream=True)\n    test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=False)\n    test_openai_v1_completions_presence_frequency_penalty(MODEL, None, stream=True)\n    test_openai_v1_completions_seed(MODEL, None)\n    test_openai_v1_completions_prompt_overlong(MODEL, None, stream=False)\n    test_openai_v1_completions_prompt_overlong(MODEL, None, stream=True)\n    test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=False)\n    test_openai_v1_completions_invalid_logprobs(MODEL, None, stream=True)\n    test_openai_v1_completions_unsupported_args(MODEL, None)\n    test_openai_v1_completions_request_cancellation(MODEL, None)\n\n    for msg in CHAT_COMPLETION_MESSAGES:\n        test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg)\n        
test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg)\n        test_openai_v1_chat_completions_n(MODEL, None, stream=False, messages=msg)\n        test_openai_v1_chat_completions_n(MODEL, None, stream=True, messages=msg)\n        test_openai_v1_chat_completions_openai_package(MODEL, None, stream=False, messages=msg)\n        test_openai_v1_chat_completions_openai_package(MODEL, None, stream=True, messages=msg)\n    test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=False)\n    test_openai_v1_chat_completions_max_tokens(MODEL, None, stream=True)\n    test_openai_v1_chat_completions_json(MODEL, None, stream=False)\n    test_openai_v1_chat_completions_json(MODEL, None, stream=True)\n    test_openai_v1_chat_completions_json_schema(MODEL, None, stream=False)\n    test_openai_v1_chat_completions_json_schema(MODEL, None, stream=True)\n    test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=False)\n    test_openai_v1_chat_completions_ignore_eos(MODEL, None, stream=True)\n    test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=False)\n    test_openai_v1_chat_completions_system_prompt_wrong_pos(MODEL, None, stream=True)\n\n    test_debug_dump_event_trace(MODEL, None)\n    test_metrics(MODEL, None)\n"
  },
  {
    "path": "tests/python/serve/server/test_server_function_call.py",
    "content": "# pylint: disable=line-too-long\n\"\"\"\nTest script for function call in chat completion. To run this script, use the following command:\nMLC_SERVE_MODEL_LIB=dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so\nMLC_SERVE_MODEL_LIB=${MLC_SERVE_MODEL_LIB} python -m pytest -x tests/python/serve/server/test_server_function_call.py\n\"\"\"\n\n# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches\nimport json\nimport os\nfrom typing import Dict, List, Optional, Tuple\n\nimport pytest\nimport requests\n\nOPENAI_V1_CHAT_COMPLETION_URL = \"http://127.0.0.1:8000/v1/chat/completions\"\n\n\ndef check_openai_nonstream_response(\n    response: Dict,\n    *,\n    model: str,\n    object_str: str,\n    num_choices: int,\n    finish_reason: List[str],\n    completion_tokens: Optional[int] = None,\n):\n    print(response)\n    assert response[\"model\"] == model\n    assert response[\"object\"] == object_str\n\n    choices = response[\"choices\"]\n    assert isinstance(choices, list)\n    assert len(choices) == num_choices\n    for idx, choice in enumerate(choices):\n        assert choice[\"index\"] == idx\n        assert choice[\"finish_reason\"] in finish_reason\n\n        # text: str\n        message = choice[\"message\"]\n        assert message[\"role\"] == \"assistant\"\n        if choice[\"finish_reason\"] == \"tool_calls\":\n            assert message[\"content\"] is None\n            assert isinstance(message[\"tool_calls\"], list)\n        else:\n            assert message[\"tool_calls\"] is None\n            assert message[\"content\"] is not None\n\n    usage = response[\"usage\"]\n    assert isinstance(usage, dict)\n    assert usage[\"total_tokens\"] == usage[\"prompt_tokens\"] + usage[\"completion_tokens\"]\n    assert usage[\"prompt_tokens\"] > 0\n\n    if completion_tokens is not None:\n        assert usage[\"completion_tokens\"] == completion_tokens\n\n\ndef 
check_openai_stream_response(\n    responses: List[Dict],\n    *,\n    model: str,\n    object_str: str,\n    num_choices: int,\n    finish_reason: str,\n    echo_prompt: Optional[str] = None,\n    suffix: Optional[str] = None,\n    stop: Optional[List[str]] = None,\n    require_substr: Optional[List[str]] = None,\n):\n    assert len(responses) > 0\n\n    finished = [False for _ in range(num_choices)]\n    outputs = [\"\" for _ in range(num_choices)]\n    for response in responses:\n        assert response[\"model\"] == model\n        assert response[\"object\"] == object_str\n\n        choices = response[\"choices\"]\n        assert isinstance(choices, list)\n        assert len(choices) == num_choices\n        for idx, choice in enumerate(choices):\n            assert choice[\"index\"] == idx\n\n            delta = choice[\"delta\"]\n            assert delta[\"role\"] == \"assistant\"\n            assert isinstance(delta[\"content\"], str)\n            outputs[idx] += delta[\"content\"]\n\n            if finished[idx]:\n                assert choice[\"finish_reason\"] == finish_reason\n            elif choice[\"finish_reason\"] is not None:\n                assert choice[\"finish_reason\"] == finish_reason\n                finished[idx] = True\n\n    for output in outputs:\n        if echo_prompt is not None:\n            assert output.startswith(echo_prompt)\n        if suffix is not None:\n            assert output.endswith(suffix)\n        if stop is not None:\n            for stop_str in stop:\n                assert stop_str not in output\n        if require_substr is not None:\n            for substr in require_substr:\n                assert substr in output\n\n\ntools = [\n    {\n        \"type\": \"function\",\n        \"function\": {\n            \"name\": \"get_current_weather\",\n            \"description\": \"Get the current weather in a given location\",\n            \"parameters\": {\n                \"type\": \"object\",\n                
\"properties\": {\n                    \"location\": {\n                        \"type\": \"string\",\n                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n                    },\n                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n                },\n                \"required\": [\"location\"],\n            },\n        },\n    }\n]\n\n\nCHAT_COMPLETION_MESSAGES = [\n    # messages #0\n    [\n        {\n            \"role\": \"user\",\n            \"content\": \"What is the current weather in Pittsburgh, PA?\",\n        }\n    ],\n    # messages #1\n    [\n        {\n            \"role\": \"user\",\n            \"content\": \"What is the current weather in Pittsburgh, PA and Tokyo, JP?\",\n        }\n    ],\n    # messages #2\n    [\n        {\n            \"role\": \"user\",\n            \"content\": \"What is the current weather in Pittsburgh, PA in fahrenheit?\",\n        }\n    ],\n]\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\n@pytest.mark.parametrize(\"messages\", CHAT_COMPLETION_MESSAGES)\ndef test_openai_v1_chat_completion_function_call(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n    messages: List[Dict[str, str]],\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n        \"tools\": tools,\n    }\n\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reason=[\"tool_calls\", \"error\"],\n        )\n    else:\n        responses = []\n        for chunk in 
response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reason=\"tool_calls\",\n        )\n\n\nif __name__ == \"__main__\":\n    model_lib = os.environ.get(\"MLC_SERVE_MODEL_LIB\")\n    if model_lib is None:\n        raise ValueError(\n            'Environment variable \"MLC_SERVE_MODEL_LIB\" not found. '\n            \"Please set it to model lib compiled by MLC LLM \"\n            \"(e.g., `./dist/gorilla-openfunctions-v1-q4f16_1_MLC/gorilla-openfunctions-v1-q4f16_1-cuda.so`) \"\n            \"which supports function calls.\"\n        )\n    MODEL = (os.path.dirname(model_lib), model_lib)\n\n    for msg in CHAT_COMPLETION_MESSAGES:\n        test_openai_v1_chat_completion_function_call(MODEL, None, stream=False, messages=msg)\n        test_openai_v1_chat_completion_function_call(MODEL, None, stream=True, messages=msg)\n"
  },
  {
    "path": "tests/python/serve/server/test_server_image.py",
    "content": "# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches\nimport json\nimport os\nfrom typing import Dict, List, Optional, Tuple\n\nimport pytest\nimport regex\nimport requests\n\nOPENAI_V1_CHAT_COMPLETION_URL = \"http://127.0.0.1:8001/v1/chat/completions\"\n\nJSON_TOKEN_PATTERN = (\n    r\"((-?(?:0|[1-9]\\d*))(\\.\\d+)?([eE][-+]?\\d+)?)|null|true|false|\"\n    r'(\"((\\\\[\"\\\\\\/bfnrt])|(\\\\u[0-9a-fA-F]{4})|[^\"\\\\\\x00-\\x1f])*\")'\n)\nJSON_TOKEN_RE = regex.compile(JSON_TOKEN_PATTERN)\n\n\ndef is_json_or_json_prefix(s: str) -> bool:\n    try:\n        json.loads(s)\n        return True\n    except json.JSONDecodeError as e:\n        # If the JSON decoder reaches the end of s, it is a prefix of a JSON string.\n        if e.pos == len(s):\n            return True\n        # Since json.loads is token-based instead of char-based, there may remain half a token after\n        # the matching position.\n        # If the left part is a prefix of a valid JSON token, the output is also valid\n        regex_match = JSON_TOKEN_RE.fullmatch(s[e.pos :], partial=True)\n        return regex_match is not None\n\n\ndef check_openai_nonstream_response(\n    response: Dict,\n    *,\n    is_chat_completion: bool,\n    model: str,\n    object_str: str,\n    num_choices: int,\n    finish_reasons: List[str],\n    completion_tokens: Optional[int] = None,\n    echo_prompt: Optional[str] = None,\n    suffix: Optional[str] = None,\n    stop: Optional[List[str]] = None,\n    require_substr: Optional[List[str]] = None,\n    json_mode: bool = False,\n):\n    assert response[\"model\"] == model\n    assert response[\"object\"] == object_str\n\n    choices = response[\"choices\"]\n    assert isinstance(choices, list)\n    assert len(choices) <= num_choices\n    texts: List[str] = [\"\" for _ in range(num_choices)]\n    for choice in choices:\n        idx = choice[\"index\"]\n        assert choice[\"finish_reason\"] in finish_reasons\n\n 
       if not is_chat_completion:\n            assert isinstance(choice[\"text\"], str)\n            texts[idx] = choice[\"text\"]\n            if echo_prompt is not None:\n                assert texts[idx]\n            if suffix is not None:\n                assert texts[idx]\n        else:\n            message = choice[\"message\"]\n            assert message[\"role\"] == \"assistant\"\n            assert isinstance(message[\"content\"], str)\n            texts[idx] = message[\"content\"]\n\n        if stop is not None:\n            for stop_str in stop:\n                assert stop_str not in texts[idx]\n        if require_substr is not None:\n            for substr in require_substr:\n                assert substr in texts[idx]\n        if json_mode:\n            assert is_json_or_json_prefix(texts[idx])\n\n    usage = response[\"usage\"]\n    assert isinstance(usage, dict)\n    assert usage[\"total_tokens\"] == usage[\"prompt_tokens\"] + usage[\"completion_tokens\"]\n    assert usage[\"prompt_tokens\"] > 0\n    if completion_tokens is not None:\n        assert usage[\"completion_tokens\"] == completion_tokens\n\n\ndef check_openai_stream_response(\n    responses: List[Dict],\n    *,\n    is_chat_completion: bool,\n    model: str,\n    object_str: str,\n    num_choices: int,\n    finish_reasons: List[str],\n    completion_tokens: Optional[int] = None,\n    echo_prompt: Optional[str] = None,\n    suffix: Optional[str] = None,\n    stop: Optional[List[str]] = None,\n    require_substr: Optional[List[str]] = None,\n    json_mode: bool = False,\n):\n    assert len(responses) > 0\n\n    finished = [False for _ in range(num_choices)]\n    outputs = [\"\" for _ in range(num_choices)]\n    for response in responses:\n        assert response[\"model\"] == model\n        assert response[\"object\"] == object_str\n\n        choices = response[\"choices\"]\n        assert isinstance(choices, list)\n        assert len(choices) <= num_choices\n        for choice in 
choices:\n            idx = choice[\"index\"]\n\n            if not is_chat_completion:\n                assert isinstance(choice[\"text\"], str)\n                outputs[idx] += choice[\"text\"]\n            else:\n                delta = choice[\"delta\"]\n                assert delta[\"role\"] == \"assistant\"\n                assert isinstance(delta[\"content\"], str)\n                outputs[idx] += delta[\"content\"]\n\n            if finished[idx]:\n                assert choice[\"finish_reason\"] in finish_reasons\n            elif choice[\"finish_reason\"] is not None:\n                assert choice[\"finish_reason\"] in finish_reasons\n                finished[idx] = True\n\n        if not is_chat_completion:\n            usage = response[\"usage\"]\n            assert isinstance(usage, dict)\n            assert usage[\"total_tokens\"] == usage[\"prompt_tokens\"] + usage[\"completion_tokens\"]\n            assert usage[\"prompt_tokens\"] > 0\n            if completion_tokens is not None:\n                assert usage[\"completion_tokens\"] <= completion_tokens\n\n    if not is_chat_completion:\n        if completion_tokens is not None:\n            assert responses[-1][\"usage\"][\"completion_tokens\"] == completion_tokens\n\n    for i, output in enumerate(outputs):\n        if echo_prompt is not None:\n            assert output.startswith(echo_prompt)\n        if suffix is not None:\n            assert output.endswith(suffix)\n        if stop is not None:\n            for stop_str in stop:\n                assert stop_str not in output\n        if require_substr is not None:\n            for substr in require_substr:\n                assert substr in output\n        if json_mode:\n            assert is_json_or_json_prefix(output)\n\n\nCHAT_COMPLETION_MESSAGES = [\n    # messages #0\n    [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\n                    \"type\": \"image_url\",\n                    
\"image_url\": \"https://llava-vl.github.io/static/images/view.jpg\",\n                },\n                {\"type\": \"text\", \"text\": \"What does this image represent?\"},\n            ],\n        },\n    ],\n    # messages #1\n    [\n        {\n            \"role\": \"user\",\n            \"content\": [\n                {\n                    \"type\": \"image_url\",\n                    \"image_url\": \"https://llava-vl.github.io/static/images/view.jpg\",\n                },\n                {\"type\": \"text\", \"text\": \"What does this image represent?\"},\n            ],\n        },\n        {\n            \"role\": \"assistant\",\n            \"content\": \"The image represents a serene and peaceful scene of a pier extending over a body of water, such as a lake or a river. The pier is made of wood and has a bench on it, providing a place for people to sit and enjoy the view. The pier is situated in a natural environment, surrounded by trees and mountains in the background. This setting creates a tranquil atmosphere, inviting visitors to relax and appreciate the beauty of the landscape.\",\n        },\n        {\n            \"role\": \"user\",\n            \"content\": \"What country is the image set in? 
Give me 10 ranked guesses and reasons why.\",\n        },\n    ],\n]\n\n\n@pytest.mark.parametrize(\"stream\", [False, True])\n@pytest.mark.parametrize(\"messages\", CHAT_COMPLETION_MESSAGES)\ndef test_openai_v1_chat_completions(\n    served_model: Tuple[str, str],\n    launch_server,  # pylint: disable=unused-argument\n    stream: bool,\n    messages: List[Dict[str, str]],\n):\n    # `served_model` and `launch_server` are pytest fixtures\n    # defined in conftest.py.\n\n    payload = {\n        \"model\": served_model[0],\n        \"messages\": messages,\n        \"stream\": stream,\n    }\n    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=180)\n    if not stream:\n        check_openai_nonstream_response(\n            response.json(),\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion\",\n            num_choices=1,\n            finish_reasons=[\"stop\"],\n        )\n    else:\n        responses = []\n        for chunk in response.iter_lines(chunk_size=512):\n            if not chunk or chunk == b\"data: [DONE]\":\n                continue\n            responses.append(json.loads(chunk.decode(\"utf-8\")[6:]))\n        check_openai_stream_response(\n            responses,\n            is_chat_completion=True,\n            model=served_model[0],\n            object_str=\"chat.completion.chunk\",\n            num_choices=1,\n            finish_reasons=[\"stop\"],\n        )\n\n\nif __name__ == \"__main__\":\n    model_lib = os.environ.get(\"MLC_SERVE_MODEL_LIB\")\n    if model_lib is None:\n        raise ValueError(\n            'Environment variable \"MLC_SERVE_MODEL_LIB\" not found. 
'\n            \"Please set it to the model library compiled by MLC LLM \"\n            \"(e.g., `dist/Llama-2-7b-chat-hf-q0f16-MLC/Llama-2-7b-chat-hf-q0f16-MLC-cuda.so`).\"\n        )\n\n    model = os.environ.get(\"MLC_SERVE_MODEL\")\n    if model is None:\n        MODEL = (os.path.dirname(model_lib), model_lib)\n    else:\n        MODEL = (model, model_lib)\n\n    for msg in CHAT_COMPLETION_MESSAGES:\n        test_openai_v1_chat_completions(MODEL, None, stream=False, messages=msg)\n        test_openai_v1_chat_completions(MODEL, None, stream=True, messages=msg)\n"
  },
  {
    "path": "tests/python/serve/test_embedding_engine.py",
    "content": "\"\"\"Embedding engine tests in MLC LLM.\n\nTests AsyncEmbeddingEngine for both direct (sync) and async embedding inference.\nReuses MLC LLM test infrastructure: markers, require_test_model pattern,\nand conventions from test_serve_engine.py.\n\nRun with real model (requires GPU + compiled embedding model):\n  MLC_SERVE_EMBEDDING_MODEL_LIB=\"path/to/model.dylib\" \\\n    pytest -m engine tests/python/serve/test_embedding_engine.py -v\n\nEnvironment variables:\n  MLC_SERVE_EMBEDDING_MODEL_LIB  Path to compiled embedding model library (required)\n  MLC_SERVE_EMBEDDING_MODEL      Path to embedding model weight directory\n                                  (optional, defaults to dirname of model lib)\n\"\"\"\n\n# pylint: disable=import-outside-toplevel,protected-access,redefined-outer-name\n\nimport asyncio\nimport os\n\nimport numpy as np\nimport pytest\n\n# Reuse MLC LLM marker system (registered in tests/python/conftest.py)\npytestmark = [pytest.mark.engine]\n\n# ---------------------------------------------------------------------------\n# Fixtures — follows pattern from serve/server/conftest.py (served_model)\n# ---------------------------------------------------------------------------\n\nEMBEDDING_MODEL_LIB = os.environ.get(\"MLC_SERVE_EMBEDDING_MODEL_LIB\")\nEMBEDDING_MODEL_DIR = os.environ.get(\n    \"MLC_SERVE_EMBEDDING_MODEL\",\n    os.path.dirname(EMBEDDING_MODEL_LIB) if EMBEDDING_MODEL_LIB else None,\n)\n\n\ndef _skip_if_no_model():\n    if EMBEDDING_MODEL_LIB is None:\n        pytest.skip(\n            'Environment variable \"MLC_SERVE_EMBEDDING_MODEL_LIB\" not found. 
'\n            \"Set it to a compiled embedding model library \"\n            \"(e.g., Qwen3-Embedding-0.6B-q0f32-MLC.dylib).\"\n        )\n    if not os.path.isfile(EMBEDDING_MODEL_LIB):\n        pytest.skip(f\"Embedding model library not found at: {EMBEDDING_MODEL_LIB}\")\n    if EMBEDDING_MODEL_DIR is None or not os.path.isdir(EMBEDDING_MODEL_DIR):\n        pytest.skip(f\"Embedding model directory not found at: {EMBEDDING_MODEL_DIR}\")\n\n\n@pytest.fixture(scope=\"module\")\ndef embedding_engine():\n    \"\"\"Module-scoped AsyncEmbeddingEngine — loaded once, shared across tests.\"\"\"\n    _skip_if_no_model()\n    from mlc_llm.serve.embedding_engine import AsyncEmbeddingEngine\n\n    engine = AsyncEmbeddingEngine(\n        model=EMBEDDING_MODEL_DIR,\n        model_lib=EMBEDDING_MODEL_LIB,\n        device=\"auto\",\n    )\n    yield engine\n    engine.terminate()\n\n\n# ---------------------------------------------------------------------------\n# Helpers — reuse cosine_similarity pattern from test_serve_engine.py\n# ---------------------------------------------------------------------------\n\n\ndef cosine_similarity(a, b):\n    \"\"\"Return cosine similarity between two vectors.\"\"\"\n    a, b = np.array(a), np.array(b)\n    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n\n\n# ===================================================================\n# Engine initialization tests\n# ===================================================================\n\n\ndef test_engine_model_type(embedding_engine):\n    \"\"\"Engine reports a valid model type.\"\"\"\n    assert embedding_engine.model_type in (\"encoder\", \"decoder\")\n\n\ndef test_engine_pooling_strategy(embedding_engine):\n    \"\"\"Engine selects appropriate default pooling strategy.\"\"\"\n    if embedding_engine.model_type == \"encoder\":\n        assert embedding_engine.pooling_strategy == \"cls\"\n    else:\n        assert embedding_engine.pooling_strategy == \"last\"\n\n\n# 
===================================================================\n# Single-text embedding\n# ===================================================================\n\n\ndef test_single_text_shape(embedding_engine):\n    \"\"\"Single text returns exactly one embedding vector.\"\"\"\n    embeddings, tokens = embedding_engine.embed([\"Hello world\"])\n    assert len(embeddings) == 1\n    assert len(embeddings[0]) > 0\n    assert tokens > 0\n\n\ndef test_single_text_unit_norm(embedding_engine):\n    \"\"\"Embedding output is L2-normalized.\"\"\"\n    embeddings, _ = embedding_engine.embed([\"Hello world\"])\n    norm = float(np.linalg.norm(embeddings[0]))\n    assert abs(norm - 1.0) < 1e-4, f\"Expected unit norm, got {norm}\"\n\n\n# ===================================================================\n# Batch embedding\n# ===================================================================\n\nBATCH_TEXTS = [\n    \"Machine learning is fascinating\",\n    \"I love pizza\",\n    \"Deep learning uses neural networks\",\n]\n\n\ndef test_batch_count(embedding_engine):\n    \"\"\"Batch embedding returns one vector per input.\"\"\"\n    embeddings, tokens = embedding_engine.embed(BATCH_TEXTS)\n    assert len(embeddings) == len(BATCH_TEXTS)\n    assert tokens > 0\n\n\ndef test_batch_all_normalized(embedding_engine):\n    \"\"\"Every vector in a batch is L2-normalized.\"\"\"\n    embeddings, _ = embedding_engine.embed(BATCH_TEXTS)\n    for i, emb in enumerate(embeddings):\n        norm = float(np.linalg.norm(emb))\n        assert abs(norm - 1.0) < 1e-4, f\"Embedding [{i}] norm={norm}\"\n\n\ndef test_batch_consistent_dimension(embedding_engine):\n    \"\"\"All embeddings in a batch have the same dimension.\"\"\"\n    embeddings, _ = embedding_engine.embed(BATCH_TEXTS)\n    dims = {len(emb) for emb in embeddings}\n    assert len(dims) == 1, f\"Inconsistent dimensions: {dims}\"\n\n\n# ===================================================================\n# Semantic quality — cosine 
similarity ranking\n# ===================================================================\n\nSIMILARITY_TEXTS = [\n    \"What is machine learning?\",\n    \"Explain deep learning algorithms\",\n    \"I want to order pizza\",\n]\n\n\ndef test_cosine_similarity_ranking(embedding_engine):\n    \"\"\"Related texts have higher cosine similarity than unrelated texts.\"\"\"\n    embeddings, _ = embedding_engine.embed(SIMILARITY_TEXTS)\n    e_ml, e_dl, e_pizza = [np.array(e) for e in embeddings]\n    sim_related = float(np.dot(e_ml, e_dl))\n    sim_unrelated = float(np.dot(e_ml, e_pizza))\n    assert (\n        sim_related > sim_unrelated\n    ), f\"Related sim ({sim_related:.4f}) should be > unrelated sim ({sim_unrelated:.4f})\"\n\n\n# ===================================================================\n# Determinism\n# ===================================================================\n\n\ndef test_deterministic_output(embedding_engine):\n    \"\"\"Same input produces identical output across calls.\"\"\"\n    text = [\"Deterministic test\"]\n    emb1, _ = embedding_engine.embed(text)\n    emb2, _ = embedding_engine.embed(text)\n    cos = cosine_similarity(emb1[0], emb2[0])\n    assert cos > 0.9999, f\"Expected deterministic output, cosine={cos}\"\n\n\n# ===================================================================\n# Async embedding\n# ===================================================================\n\n\ndef test_async_embed(embedding_engine):\n    \"\"\"async_embed produces the same result as sync embed.\"\"\"\n    text = [\"Async test\"]\n    sync_emb, sync_tokens = embedding_engine.embed(text)\n\n    loop = asyncio.new_event_loop()\n    try:\n        async_emb, async_tokens = loop.run_until_complete(embedding_engine.async_embed(text))\n    finally:\n        loop.close()\n\n    assert sync_tokens == async_tokens\n    cos = cosine_similarity(sync_emb[0], async_emb[0])\n    assert cos > 0.9999, f\"Async vs sync mismatch, cosine={cos}\"\n\n\n# 
===================================================================\n# Edge cases\n# ===================================================================\n\n\ndef test_empty_string(embedding_engine):\n    \"\"\"Empty string still produces a valid embedding for both encoder and decoder models.\"\"\"\n    embeddings, tokens = embedding_engine.embed([\"\"])\n    assert len(embeddings) == 1\n    assert len(embeddings[0]) > 0\n    assert tokens > 0\n\n\n# ===================================================================\n# Long text handling (model-type dependent)\n# ===================================================================\n\n\ndef test_long_text_decoder_chunked_prefill(embedding_engine):\n    \"\"\"[Decoder only] Text longer than prefill_chunk_size triggers chunked prefill.\n    ~5000 tokens span 3 chunks at prefill_chunk_size=2048. Output is a unit-norm embedding.\"\"\"\n    if embedding_engine.model_type != \"decoder\":\n        pytest.skip(\"Chunked prefill is decoder-only\")\n    long_text = \"word \" * 5000\n    embeddings, tokens = embedding_engine.embed([long_text])\n    assert tokens > 2048, f\"Expected >2048 tokens to trigger chunking, got {tokens}\"\n    norm = float(np.linalg.norm(embeddings[0]))\n    assert abs(norm - 1.0) < 1e-3\n\n\ndef _get_encoder_tokens(embedding_engine, text):\n    \"\"\"Replicate encoder preprocessing: tokenize and add [CLS]/[SEP].\"\"\"\n    tokens = list(embedding_engine.tokenizer.encode(text))\n    if embedding_engine._cls_token_id is not None and (\n        len(tokens) == 0 or tokens[0] != embedding_engine._cls_token_id\n    ):\n        tokens = [embedding_engine._cls_token_id] + tokens\n    if embedding_engine._sep_token_id is not None and (\n        len(tokens) == 0 or tokens[-1] != embedding_engine._sep_token_id\n    ):\n        tokens = tokens + [embedding_engine._sep_token_id]\n 
   return tokens\n\n\ndef test_long_text_encoder_truncation(embedding_engine):  # pylint: disable=too-many-locals\n    \"\"\"[Encoder only] Text exceeding prefill_chunk_size is truncated.\n    Two texts with the same shared prefix but different suffixes beyond the\n    limit should produce identical embeddings, since the suffix is truncated\n    and the retained token prefixes are verified to be identical.\"\"\"\n    if embedding_engine.model_type != \"encoder\":\n        pytest.skip(\"Truncation test is encoder-only\")\n    prefill_chunk = embedding_engine._metadata.get(\"prefill_chunk_size\", 512)\n\n    # Dynamically construct input that exceeds prefill_chunk_size.\n    unit = \"machine learning is great \"\n    suffix_a = \" alpha beta gamma \" * 200\n    suffix_b = \" totally different ending \" * 200\n    unit_tokens = len(list(embedding_engine.tokenizer.encode(unit)))\n    repeats = max(1, prefill_chunk // max(unit_tokens, 1) + 64)\n\n    # Increase prefix length until both inputs exceed prefill_chunk_size\n    # and their truncated token prefixes are identical.\n    while True:\n        shared_prefix = unit * repeats\n        full_tokens_a = _get_encoder_tokens(embedding_engine, shared_prefix + suffix_a)\n        full_tokens_b = _get_encoder_tokens(embedding_engine, shared_prefix + suffix_b)\n        if (\n            len(full_tokens_a) > prefill_chunk\n            and len(full_tokens_b) > prefill_chunk\n            and full_tokens_a[:prefill_chunk] == full_tokens_b[:prefill_chunk]\n        ):\n            break\n        repeats += 64\n        assert repeats < 200000, \"Failed to construct truncation test inputs\"\n\n    text_a = shared_prefix + suffix_a\n    text_b = shared_prefix + suffix_b\n\n    emb_a, tokens_a = embedding_engine.embed([text_a])\n    emb_b, tokens_b = embedding_engine.embed([text_b])\n\n    # Verify truncation happened\n    assert (\n        tokens_a <= prefill_chunk\n    ), f\"Encoder should truncate to {prefill_chunk}, got {tokens_a} 
tokens\"\n    assert tokens_b <= prefill_chunk\n    # Both should be valid unit-norm embeddings\n    assert abs(float(np.linalg.norm(emb_a[0])) - 1.0) < 1e-3\n    assert abs(float(np.linalg.norm(emb_b[0])) - 1.0) < 1e-3\n\n    # Both truncated to identical token sequences → embeddings must match\n    cos = cosine_similarity(emb_a[0], emb_b[0])\n    assert cos > 0.999, f\"Same truncated tokens should match, cosine={cos:.6f}\"\n\n\ndef test_long_vs_short_semantic_quality(embedding_engine):\n    \"\"\"Long text should still capture semantic meaning correctly.\n    Decoder: chunked prefill preserves full context.\n    Encoder: truncation keeps only the leading tokens.\"\"\"\n    short_ml = \"Machine learning enables systems to learn from data\"\n    long_ml = (\n        \"Machine learning is a fascinating field of study. \" * 200\n        + \"It enables systems to learn from data.\"\n    )\n    pizza = \"I want to order a pepperoni pizza for dinner\"\n\n    embs, _ = embedding_engine.embed([short_ml, long_ml, pizza])\n    e_short, e_long, e_pizza = [np.array(e) for e in embs]\n\n    sim_same_topic = float(np.dot(e_short, e_long))\n    sim_different = float(np.dot(e_short, e_pizza))\n    assert (\n        sim_same_topic > sim_different\n    ), f\"Same topic ({sim_same_topic:.4f}) should be > different ({sim_different:.4f})\"\n\n\ndef test_unicode_text(embedding_engine):\n    \"\"\"Unicode input is handled correctly.\"\"\"\n    texts = [\"Привет мир\", \"你好世界\", \"こんにちは世界\"]\n    embeddings, _ = embedding_engine.embed(texts)\n    assert len(embeddings) == 3\n    for emb in embeddings:\n        assert abs(float(np.linalg.norm(emb)) - 1.0) < 1e-4\n\n\n# ===================================================================\n# Standalone runner (like test_serve_engine.py)\n# ===================================================================\n\nif __name__ == \"__main__\":\n    _skip_if_no_model()\n    from mlc_llm.serve.embedding_engine import AsyncEmbeddingEngine\n\n    engine = 
AsyncEmbeddingEngine(\n        model=EMBEDDING_MODEL_DIR,\n        model_lib=EMBEDDING_MODEL_LIB,\n        device=\"auto\",\n    )\n    try:\n        test_engine_model_type(engine)\n        test_engine_pooling_strategy(engine)\n        test_single_text_shape(engine)\n        test_single_text_unit_norm(engine)\n        test_batch_count(engine)\n        test_batch_all_normalized(engine)\n        test_batch_consistent_dimension(engine)\n        test_cosine_similarity_ranking(engine)\n        test_deterministic_output(engine)\n        test_async_embed(engine)\n        test_empty_string(engine)\n        test_long_text_decoder_chunked_prefill(engine)\n        test_long_text_encoder_truncation(engine)\n        test_long_vs_short_semantic_quality(engine)\n        test_unicode_text(engine)\n        print(\"\\nAll embedding engine tests passed!\")\n    finally:\n        engine.terminate()\n"
  },
  {
    "path": "tests/python/serve/test_event_trace_recorder.py",
    "content": "# pylint: disable=missing-module-docstring,missing-function-docstring\nimport json\n\nimport pytest\n\nfrom mlc_llm.serve.event_trace_recorder import EventTraceRecorder\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\ndef test_event_trace_recorder():\n    trace_recorder = EventTraceRecorder()\n    request_ids = [\"x\", \"y\"]\n    num_decode = 5\n\n    for request_id in request_ids:\n        trace_recorder.add_event(request_id, event=\"start tokenization\")\n        trace_recorder.add_event(request_id, event=\"finish tokenization\")\n        trace_recorder.add_event(request_id, event=\"add request\")\n        trace_recorder.add_event(request_id, event=\"start embed\")\n        trace_recorder.add_event(request_id, event=\"finish embed\")\n        trace_recorder.add_event(request_id, event=\"start prefill\")\n        trace_recorder.add_event(request_id, event=\"finish prefill\")\n\n    for _ in range(num_decode):\n        for request_id in request_ids:\n            trace_recorder.add_event(request_id, event=\"start decode\")\n            trace_recorder.add_event(request_id, event=\"finish decode\")\n    for request_id in request_ids:\n        trace_recorder.add_event(request_id, event=\"start detokenization\")\n        trace_recorder.add_event(request_id, event=\"finish detokenization\")\n\n    events = json.loads(trace_recorder.dump_json())\n    decode_count = {}\n    for event in events:\n        request_id = event[\"tid\"]\n        if event[\"name\"].startswith(\"decode\"):\n            if request_id not in decode_count:\n                decode_count[request_id] = 1\n            else:\n                decode_count[request_id] += 1\n\n    for _, decode_cnt in decode_count.items():\n        assert decode_cnt == num_decode * 2, decode_cnt\n\n\nif __name__ == \"__main__\":\n    test_event_trace_recorder()\n"
  },
  {
    "path": "tests/python/serve/test_radix_tree.py",
    "content": "import pytest\n\nfrom mlc_llm.serve import PagedRadixTree\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\ndef test_add():\n    prt = PagedRadixTree()\n    prt.add(0)\n    assert list(prt.get(0)) == []\n    prt.add(1)\n    assert list(prt.get(1)) == []\n\n\ndef test_remove():\n    prt = PagedRadixTree()\n    capacity = prt.free_capacity()\n    prt.add(0)\n    prt.remove(0)\n    prt.add(0)\n    prt.extend(0, [1 for _ in range(200)])\n    prt.remove(0)\n    assert prt.free_capacity() == capacity\n\n    prt.add(1)\n    prt.extend(1, [1 for _ in range(200)])\n    capacity = prt.free_capacity()\n    prt.add(2)\n    prt.extend(2, [1 for _ in range(100)] + [2 for _ in range(100)])\n    prt.remove(2)\n    assert prt.free_capacity() == capacity\n\n    prt.add(3)\n    prt.extend(3, [1 for _ in range(200)])\n    prt.remove(3)\n    assert prt.free_capacity() == capacity\n\n    prt.add(4)\n    prt.add(5)\n    prt.add(6)\n    assert prt.free_capacity() == capacity\n    prt.remove(4)\n    assert prt.free_capacity() == capacity\n    prt.remove(5)\n    assert prt.free_capacity() == capacity\n    prt.remove(6)\n    assert prt.free_capacity() == capacity\n\n\ndef test_extend():\n    prt = PagedRadixTree()\n    L = prt.free_capacity() // 64\n    H = L // 2\n    Q = L // 4\n    seq_id = 0\n    for start_pos in [0, H, L, L + H]:\n        for length in [Q, L - H, L, 2 * L - H, 2 * L]:\n            prt.add(seq_id)\n            if start_pos:\n                tokens_1 = [seq_id for _ in range(start_pos)]\n                prt.extend(seq_id, tokens_1)\n                assert list(prt.get(seq_id)) == tokens_1\n            else:\n                tokens_1 = []\n            tokens_2 = [seq_id for _ in range(length)]\n            prt.extend(seq_id, tokens_2)\n            assert list(prt.get(seq_id)) == tokens_1 + tokens_2\n            seq_id += 1\n\n\ndef test_fork():\n    prt = PagedRadixTree()\n    L = prt.free_capacity() // 64\n    H = L // 2\n    Q = L 
// 4\n    seq_id = 0\n    length_list = [Q, H, L, L + Q, L + H, L * 2]\n    for p_idx in range(1, len(length_list)):\n        for c_idx in range(0, p_idx + 1):\n            prt.add(seq_id)\n            tokens = [seq_id for _ in range(length_list[p_idx])]\n            prt.extend(seq_id, tokens)\n            prt.fork(seq_id + 1, seq_id, length_list[c_idx])\n            assert list(prt.get(seq_id + 1)) == tokens[: length_list[c_idx]]\n            seq_id += 2\n\n\ndef test_fork_2():\n    prt = PagedRadixTree()\n    prt.add(0)\n    prt.extend(0, [0, 1, 2, 3])\n    prt.fork(1, 0, 3)\n    prt.extend(1, [4])\n    prt.fork(2, 0, 3)\n    prt.extend(2, [5])\n    assert prt.match([0, 1, 2, 4]) == (4, (1,))\n    assert prt.match([0, 1, 2, 5]) == (4, (2,))\n\n\ndef test_rollback():\n    prt = PagedRadixTree()\n    L = prt.free_capacity() // 64\n    H = L // 2\n    Q = L // 4\n    seq_id = 0\n    for start_pos in [H, L, L + H, 2 * L, 3 * L + H]:\n        for length in [Q, H, L + Q, 2 * L, 2 * L + Q]:\n            if length > start_pos:\n                continue\n            prt.add(seq_id)\n            tokens = [seq_id for _ in range(start_pos)]\n            prt.extend(seq_id, tokens)\n            prt.rollback(seq_id, length)\n            assert list(prt.get(seq_id)) == tokens[:-length]\n            seq_id += 1\n\n    for start_pos in [H, L, L + H, 2 * L, 3 * L + H]:\n        for length in [Q, H, L + Q, 2 * L, 2 * L + Q]:\n            if length > start_pos:\n                continue\n            prt.add(seq_id)\n            tokens = [seq_id for _ in range(start_pos)]\n            prt.extend(seq_id, tokens)\n            prt.fork(seq_id + 1, seq_id, start_pos)\n            prt.rollback(seq_id + 1, length)\n            assert list(prt.get(seq_id + 1)) == tokens[:-length]\n            seq_id += 2\n\n\nif __name__ == \"__main__\":\n    test_add()\n    test_remove()\n    test_extend()\n    test_fork()\n    test_fork_2()\n    test_rollback()\n"
  },
  {
    "path": "tests/python/serve/test_serve_async_engine.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable\nimport asyncio\nfrom typing import List\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import AsyncMLCEngine, EngineConfig\nfrom mlc_llm.testing import require_test_model\n\nprompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous for? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where did milk tea originate? Please elaborate in detail.\",\n    \"Where is the southernmost place in the United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.\",\n]\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\nasync def test_engine_generate(model: str):\n    # Create engine\n    async_engine = AsyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n\n    num_requests = 10\n    max_tokens = 256\n    generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7)\n\n    output_texts: List[List[str]] = [\n        [\"\" for _ in range(generation_cfg.n)] for _ in range(num_requests)\n    ]\n\n    async def generate_task(\n        async_engine: AsyncMLCEngine,\n        prompt: str,\n        generation_cfg: GenerationConfig,\n        request_id: str,\n    ):\n        print(f\"generate task for request {request_id}\")\n        rid = int(request_id)\n        async for delta_outputs in async_engine._generate(\n            prompt, generation_cfg, request_id=request_id\n        ):\n            if len(delta_outputs) == generation_cfg.n:\n                for i, delta_output in enumerate(delta_outputs):\n                    output_texts[rid][i] += delta_output.delta_text\n            else:\n                assert len(delta_outputs) == 1\n                assert len(delta_outputs[0].request_final_usage_json_str) != 0\n\n    tasks = [\n        asyncio.create_task(\n            generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i))\n        )\n        for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    # Print output.\n    print(\"All finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    async_engine.terminate()\n    del 
async_engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\nasync def test_chat_completion(model: str):\n    # Create engine\n    async_engine = AsyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n\n    num_requests = 2\n    max_tokens = 32\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    async def generate_task(prompt: str, request_id: str):\n        print(f\"generate chat completion task for request {request_id}\")\n        rid = int(request_id)\n        async for response in await async_engine.chat.completions.create(\n            messages=[{\"role\": \"user\", \"content\": prompt}],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=request_id,\n            stream=True,\n        ):\n            for choice in response.choices:\n                assert choice.delta.role == \"assistant\"\n                assert isinstance(choice.delta.content, str)\n                output_texts[rid][choice.index] += choice.delta.content\n\n    tasks = [\n        asyncio.create_task(generate_task(prompts[i], request_id=str(i)))\n        for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    # Print output.\n    print(\"Chat completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    async_engine.terminate()\n    del async_engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\nasync def test_chat_completion_non_stream(model: str):\n    # Create engine\n    async_engine = AsyncMLCEngine(\n        model=model,\n        mode=\"server\",\n     
   engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n\n    num_requests = 2\n    max_tokens = 32\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    async def generate_task(prompt: str, request_id: str):\n        print(f\"generate chat completion task for request {request_id}\")\n        rid = int(request_id)\n        response = await async_engine.chat.completions.create(\n            messages=[{\"role\": \"user\", \"content\": prompt}],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=request_id,\n        )\n        for choice in response.choices:\n            assert choice.message.role == \"assistant\"\n            assert isinstance(choice.message.content, str)\n            output_texts[rid][choice.index] += choice.message.content\n\n    tasks = [\n        asyncio.create_task(generate_task(prompts[i], request_id=str(i)))\n        for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    # Print output.\n    print(\"Chat completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    async_engine.terminate()\n    del async_engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\nasync def test_completion(model: str):\n    # Create engine\n    async_engine = AsyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n\n    num_requests = 2\n    max_tokens = 128\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    async def generate_task(prompt: str, request_id: str):\n        
print(f\"generate completion task for request {request_id}\")\n        rid = int(request_id)\n        async for response in await async_engine.completions.create(\n            prompt=prompt,\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=request_id,\n            stream=True,\n            extra_body={\"debug_config\": {\"ignore_eos\": True}},\n        ):\n            for choice in response.choices:\n                output_texts[rid][choice.index] += choice.text\n\n    tasks = [\n        asyncio.create_task(generate_task(prompts[i], request_id=str(i)))\n        for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    # Print output.\n    print(\"Completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    async_engine.terminate()\n    del async_engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\nasync def test_completion_non_stream(model: str):\n    # Create engine\n    async_engine = AsyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n\n    num_requests = 2\n    max_tokens = 128\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    async def generate_task(prompt: str, request_id: str):\n        print(f\"generate completion task for request {request_id}\")\n        rid = int(request_id)\n        response = await async_engine.completions.create(\n            prompt=prompt,\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=request_id,\n            extra_body={\"debug_config\": 
{\"ignore_eos\": True}},\n        )\n        for choice in response.choices:\n            output_texts[rid][choice.index] += choice.text\n\n    tasks = [\n        asyncio.create_task(generate_task(prompts[i], request_id=str(i)))\n        for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    # Print output.\n    print(\"Completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    async_engine.terminate()\n    del async_engine\n\n\nif __name__ == \"__main__\":\n    asyncio.run(test_engine_generate())\n    asyncio.run(test_chat_completion())\n    asyncio.run(test_chat_completion_non_stream())\n    asyncio.run(test_completion())\n    asyncio.run(test_completion_non_stream())\n"
  },
  {
    "path": "tests/python/serve/test_serve_async_engine_spec.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals\nimport asyncio\nfrom typing import List\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import AsyncMLCEngine, EngineConfig\nfrom mlc_llm.testing import require_test_model\n\nprompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous for? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where did milk tea originate? Please elaborate in detail.\",\n    \"Where is the southernmost place in the United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.\",\n]\n\n\n@require_test_model(\n    \"Llama-2-7b-chat-hf-q0f16-MLC\",\n    \"Llama-2-7b-chat-hf-q4f16_1-MLC\",\n)\nasync def test_engine_generate(model: str, small_model: str):\n    # Create engine\n    async_engine = AsyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            additional_models=[small_model],\n            speculative_mode=\"small_draft\",\n        ),\n    )\n\n    num_requests = 10\n    max_tokens = 256\n    generation_cfg = GenerationConfig(max_tokens=max_tokens)\n\n    output_texts: List[List[str]] = [\n        [\"\" for _ in range(generation_cfg.n)] for _ in range(num_requests)\n    ]\n\n    async def generate_task(\n        async_engine: AsyncMLCEngine,\n        prompt: str,\n        generation_cfg: GenerationConfig,\n        request_id: str,\n    ):\n        print(f\"generate task for request {request_id}\")\n        rid = int(request_id)\n        async for delta_outputs in async_engine._generate(\n            prompt, generation_cfg, request_id=request_id\n        ):\n            assert len(delta_outputs) == generation_cfg.n\n            for i, delta_output in enumerate(delta_outputs):\n                output_texts[rid][i] += delta_output.delta_text\n\n    tasks = [\n        asyncio.create_task(\n            generate_task(async_engine, prompts[i], generation_cfg, request_id=str(i))\n        )\n        for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    # Print output.\n    print(\"All finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    async_engine.terminate()\n    del async_engine\n\n\nif __name__ == \"__main__\":\n    
asyncio.run(test_engine_generate())\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable\nfrom typing import List\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import EngineConfig, MLCEngine\nfrom mlc_llm.testing import require_test_model\n\nprompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous for? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where did milk tea originate? Please elaborate in detail.\",\n    \"Where is the southernmost place in the United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.\",\n]\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_generate(model: str):\n    # Create engine\n    engine = MLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n        ),\n    )\n\n    num_requests = 10\n    max_tokens = 256\n    generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7)\n\n    output_texts: List[List[str]] = [\n        [\"\" for _ in range(generation_cfg.n)] for _ in range(num_requests)\n    ]\n    for rid in range(num_requests):\n        print(f\"generating for request {rid}\")\n        for delta_outputs in engine._generate(prompts[rid], generation_cfg, request_id=str(rid)):\n            assert len(delta_outputs) == generation_cfg.n\n            for i, delta_output in enumerate(delta_outputs):\n                output_texts[rid][i] += delta_output.delta_text\n\n    # Print output.\n    print(\"All finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    engine.terminate()\n    del engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_chat_completion(model: str):\n    # Create engine\n    engine = MLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n        ),\n    )\n\n    num_requests = 2\n    max_tokens = 64\n    n = 2\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    for rid in range(num_requests):\n        print(f\"chat completion for request {rid}\")\n        for response in engine.chat.completions.create(\n            messages=[{\"role\": \"user\", 
\"content\": prompts[rid]}],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n            stream=True,\n        ):\n            for choice in response.choices:\n                assert choice.delta.role == \"assistant\"\n                assert isinstance(choice.delta.content, str)\n                output_texts[rid][choice.index] += choice.delta.content\n\n    # Print output.\n    print(\"Chat completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    engine.terminate()\n    del engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_chat_completion_non_stream(model: str):\n    # Create engine\n    engine = MLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n        ),\n    )\n\n    num_requests = 2\n    max_tokens = 64\n    n = 2\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    for rid in range(num_requests):\n        print(f\"chat completion for request {rid}\")\n        response = engine.chat.completions.create(\n            messages=[{\"role\": \"user\", \"content\": prompts[rid]}],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n        )\n        for choice in response.choices:\n            assert choice.message.role == \"assistant\"\n            assert isinstance(choice.message.content, str)\n            output_texts[rid][choice.index] += choice.message.content\n\n    # Print output.\n    print(\"Chat completion all finished\")\n    for req_id, outputs in 
enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    engine.terminate()\n    del engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_completion(model: str):\n    # Create engine\n    engine = MLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n        ),\n    )\n\n    num_requests = 2\n    max_tokens = 128\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    for rid in range(num_requests):\n        print(f\"completion for request {rid}\")\n        for response in engine.completions.create(\n            prompt=prompts[rid],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n            stream=True,\n            extra_body={\"debug_config\": {\"ignore_eos\": True}},\n        ):\n            for choice in response.choices:\n                output_texts[rid][choice.index] += choice.text\n\n    # Print output.\n    print(\"Completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    engine.terminate()\n    del engine\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_completion_non_stream(model: str):\n    # Create engine\n    engine = MLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n        
),\n    )\n\n    num_requests = 2\n    max_tokens = 128\n    n = 1\n    output_texts: List[List[str]] = [[\"\" for _ in range(n)] for _ in range(num_requests)]\n\n    for rid in range(num_requests):\n        print(f\"completion for request {rid}\")\n        response = engine.completions.create(\n            prompt=prompts[rid],\n            model=model,\n            max_tokens=max_tokens,\n            n=n,\n            request_id=str(rid),\n            extra_body={\"debug_config\": {\"ignore_eos\": True}},\n        )\n        for choice in response.choices:\n            output_texts[rid][choice.index] += choice.text\n\n    # Print output.\n    print(\"Completion all finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    engine.terminate()\n    del engine\n\n\nif __name__ == \"__main__\":\n    test_engine_generate()\n    test_chat_completion()\n    test_chat_completion_non_stream()\n    test_completion()\n    test_completion_non_stream()\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine_grammar.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable\nimport asyncio\nimport json\nimport random\nfrom typing import Dict, List, Literal\n\nfrom pydantic import BaseModel\n\nfrom mlc_llm.protocol.debug_protocol import DebugConfig\nfrom mlc_llm.protocol.openai_api_protocol import ChatCompletionResponse\nfrom mlc_llm.serve import AsyncMLCEngine, MLCEngine\nfrom mlc_llm.testing import require_test_model\n\nLLAMA_2_MODEL = \"Llama-2-7b-chat-hf-q4f16_1-MLC\"\nLLAMA_3_MODEL = \"Meta-Llama-3-8B-Instruct-q4f16_1-MLC\"\n\n\n@require_test_model(LLAMA_3_MODEL)\ndef test_batch_generation_with_grammar(model: str):\n    # Engine\n    engine = MLCEngine(model=model, mode=\"server\")\n\n    # Inputs\n    system_prompt = \"You are a helpful assistant. Always respond only with json.\"\n    prompts_list = [\n        \"Generate a JSON string containing 20 objects:\",\n        \"Generate a JSON containing a non-empty list:\",\n        \"Generate a JSON with 5 elements:\",\n        \"Generate a JSON with a number list, counting from 1 to 20:\",\n    ]\n\n    repeat = 3\n    top_p = 0.9\n    temperature = 0.6\n    max_tokens = 4096\n\n    # non-json output\n    responses_text: List[ChatCompletionResponse] = []\n    for _ in range(repeat):\n        for p in prompts_list:\n            print(f\"Start generation task for request {len(responses_text)}\")\n            responses_text.append(\n                engine.chat.completions.create(\n                    messages=[\n                        {\"role\": \"system\", \"content\": system_prompt},\n                        {\"role\": \"user\", \"content\": p},\n                    ],\n                    response_format={\"type\": \"text\"},\n                    top_p=top_p,\n                    temperature=temperature,\n                    max_tokens=max_tokens,\n                    seed=random.randint(0, 1 << 30),\n             
       extra_body={\"debug_config\": DebugConfig(grammar_execution_mode=\"constraint\")},\n                )\n            )\n\n    print(\"Text output\")\n    for req_id, response in enumerate(responses_text):\n        prompt = prompts_list[req_id % len(prompts_list)]\n        output = response.choices[0].message.content\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n\n    # json output\n    responses_json: List[ChatCompletionResponse] = []\n    for _ in range(repeat):\n        for p in prompts_list:\n            print(f\"Start generation task for request {len(responses_json)}\")\n            responses_json.append(\n                engine.chat.completions.create(\n                    messages=[\n                        {\"role\": \"system\", \"content\": system_prompt},\n                        {\"role\": \"user\", \"content\": p},\n                    ],\n                    response_format={\"type\": \"json_object\"},\n                    top_p=top_p,\n                    temperature=temperature,\n                    seed=random.randint(0, 1 << 30),\n                )\n            )\n\n    print(\"JSON output\")\n    for req_id, response in enumerate(responses_json):\n        prompt = prompts_list[req_id % len(prompts_list)]\n        output = str(response.choices[0].message.content)\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n        json.loads(output)\n\n    print(\"Engine metrics:\", engine.metrics())\n\n    engine.terminate()\n\n\n@require_test_model(LLAMA_3_MODEL)\ndef test_batch_generation_with_schema(model: str):\n    # Create engine\n    engine = MLCEngine(model=model, mode=\"server\")\n\n    class Product(BaseModel):\n        product_id: int\n        is_available: bool\n        price: float\n        is_featured: Literal[True]\n        category: Literal[\"Electronics\", \"Clothing\", \"Food\"]\n        tags: List[str]\n        stock: Dict[str, int]\n\n  
  schema_str = json.dumps(Product.model_json_schema())\n\n    system_prompt = (\n        \"You are a helpful assistant. Always respond only with JSON based on the \"\n        f\"following JSON schema: {schema_str}.\"\n    )\n    prompt = \"Generate a JSON that describes the product according to the given JSON schema.\"\n\n    repeat = 8\n    top_p = 0.9\n    temperature = 0.6\n    max_tokens = 4096\n\n    # non-json output\n    responses_text: List[ChatCompletionResponse] = []\n    for i in range(repeat):\n        print(f\"Start generation task for request {i}\")\n        responses_text.append(\n            engine.chat.completions.create(\n                messages=[\n                    {\"role\": \"system\", \"content\": system_prompt},\n                    {\"role\": \"user\", \"content\": prompt},\n                ],\n                response_format={\"type\": \"text\"},\n                top_p=top_p,\n                temperature=temperature,\n                max_tokens=max_tokens,\n                seed=random.randint(0, 1 << 30),\n                extra_body={\"debug_config\": DebugConfig(grammar_execution_mode=\"constraint\")},\n            )\n        )\n\n    print(\"Text output\")\n    for req_id, response in enumerate(responses_text):\n        output = response.choices[0].message.content\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n\n    # json output without schema\n    responses_json: List[ChatCompletionResponse] = []\n    for i in range(repeat):\n        print(f\"Start generation task for request {i}\")\n        responses_json.append(\n            engine.chat.completions.create(\n                messages=[\n                    {\"role\": \"system\", \"content\": system_prompt},\n                    {\"role\": \"user\", \"content\": prompt},\n                ],\n                response_format={\"type\": \"json_object\"},\n                top_p=top_p,\n                temperature=temperature,\n            
    max_tokens=max_tokens,\n                seed=random.randint(0, 1 << 30),\n                extra_body={\"debug_config\": DebugConfig(grammar_execution_mode=\"constraint\")},\n            )\n        )\n\n    print(\"JSON output\")\n    for req_id, response in enumerate(responses_json):\n        output = response.choices[0].message.content\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n\n    # json output with schema\n    responses_schema: List[ChatCompletionResponse] = []\n    for i in range(repeat):\n        print(f\"Start generation task for request {i}\")\n        responses_schema.append(\n            engine.chat.completions.create(\n                messages=[\n                    {\"role\": \"system\", \"content\": system_prompt},\n                    {\"role\": \"user\", \"content\": prompt},\n                ],\n                response_format={\"type\": \"json_object\", \"schema\": schema_str},\n                top_p=top_p,\n                temperature=temperature,\n                max_tokens=max_tokens,\n                seed=random.randint(0, 1 << 30),\n                extra_body={\"debug_config\": DebugConfig(grammar_execution_mode=\"constraint\")},\n            )\n        )\n\n    print(\"JSON Schema output\")\n    for req_id, response in enumerate(responses_schema):\n        output = response.choices[0].message.content\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n\n    print(\"Engine metrics:\", engine.metrics())\n\n    engine.terminate()\n\n\n@require_test_model(LLAMA_3_MODEL)\ndef test_batch_generation_jump_forward(model: str, jump_forward: bool = True, repeat: int = 1):\n    # Create engine\n    engine = MLCEngine(model=model, mode=\"server\")\n\n    class Product(BaseModel):\n        product_id: int\n        is_available: bool\n        price: float\n        is_featured: Literal[True]\n        category: Literal[\"Electronics\", \"Clothing\", 
\"Food\"]\n        tags: List[str]\n        stock: Dict[str, int]\n\n    schema_str = json.dumps(Product.model_json_schema())\n\n    system_prompt = (\n        \"You are a helpful assistant. Always respond only with JSON based on the \"\n        f\"following JSON schema: {schema_str}.\"\n    )\n    prompt = \"Generate a JSON that describes the product according to the given JSON schema.\"\n\n    top_p = 0.9\n    temperature = 0.6\n    max_tokens = 4096\n    grammar_execution_mode = \"jump_forward\" if jump_forward else \"constraint\"\n\n    # json output with schema\n    responses: List[ChatCompletionResponse] = []\n    for i in range(repeat):\n        print(f\"Start generation task for request {i}\")\n        responses.append(\n            engine.chat.completions.create(\n                messages=[\n                    {\"role\": \"system\", \"content\": system_prompt},\n                    {\"role\": \"user\", \"content\": prompt},\n                ],\n                response_format={\"type\": \"json_object\", \"schema\": schema_str},\n                top_p=top_p,\n                temperature=temperature,\n                max_tokens=max_tokens,\n                seed=random.randint(0, 1 << 30),\n                extra_body={\n                    \"debug_config\": DebugConfig(grammar_execution_mode=grammar_execution_mode)\n                },\n            )\n        )\n\n    print(f\"Jump forward: {jump_forward}, Repeat: {repeat}\")\n    for req_id, response in enumerate(responses):\n        output = response.choices[0].message.content\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n\n    print(\"Engine metrics:\", engine.metrics())\n\n    engine.terminate()\n\n\n@require_test_model(LLAMA_3_MODEL)\nasync def run_async_engine(\n    model: str,\n    mode: Literal[\"text\", \"json\", \"schema\"] = \"schema\",\n    jump_forward: bool = True,\n    num_requests: int = 8,\n):\n    # Create engine\n    async_engine = 
AsyncMLCEngine(model=model, mode=\"server\")\n\n    class Product(BaseModel):\n        product_id: int\n        is_available: bool\n        price: float\n        is_featured: Literal[True]\n        category: Literal[\"Electronics\", \"Clothing\", \"Food\"]\n        tags: List[str]\n        stock: Dict[str, int]\n\n    schema_str = json.dumps(Product.model_json_schema())\n\n    if mode == \"text\":\n        response_format = {\"type\": \"text\"}\n    elif mode == \"json\":\n        response_format = {\"type\": \"json_object\"}\n    elif mode == \"schema\":\n        response_format = {\"type\": \"json_object\", \"schema\": schema_str}\n\n    system_prompt = (\n        \"You are a helpful assistant. Always respond only with JSON based on the \"\n        f\"following JSON schema: {schema_str}.\"\n    )\n    prompt = \"Generate a JSON that describes the product according to the given JSON schema.\"\n\n    top_p = 0.9\n    temperature = 0.6\n    max_tokens = 4096\n    grammar_execution_mode = \"jump_forward\" if jump_forward else \"constraint\"\n\n    responses = [\"\" for _ in range(num_requests)]\n\n    async def generate_task(prompt: str, request_id: str):\n        print(f\"Start generation task for request {request_id}\")\n        rid = int(request_id)\n        async for response in await async_engine.chat.completions.create(\n            messages=[\n                {\"role\": \"system\", \"content\": system_prompt},\n                {\"role\": \"user\", \"content\": prompt},\n            ],\n            response_format=response_format,\n            top_p=top_p,\n            temperature=temperature,\n            max_tokens=max_tokens,\n            seed=random.randint(0, 1 << 30),\n            stream=True,\n            extra_body={\"debug_config\": DebugConfig(grammar_execution_mode=grammar_execution_mode)},\n        ):\n            assert len(response.choices) == 1\n            choice = response.choices[0]\n            assert choice.delta.role == \"assistant\"\n      
      assert isinstance(choice.delta.content, str)\n            responses[rid] += choice.delta.content\n\n    tasks = [\n        asyncio.create_task(generate_task(prompt, request_id=str(i))) for i in range(num_requests)\n    ]\n\n    await asyncio.gather(*tasks)\n\n    print(f\"Mode: {mode}, Jump forward: {jump_forward}, Num requests: {num_requests}\")\n    for req_id, output in enumerate(responses):\n        print(f\"Prompt {req_id}: {prompt}\")\n        print(f\"Output {req_id}: {output}\\n\")\n\n    print(\"Engine metrics:\", await async_engine.metrics())\n\n    async_engine.terminate()\n    del async_engine\n\n\ndef test_async_engine(\n    mode: Literal[\"text\", \"json\", \"schema\"] = \"schema\",\n    jump_forward: bool = True,\n    num_requests: int = 8,\n):\n    asyncio.run(run_async_engine(mode, jump_forward, num_requests))\n\n\nif __name__ == \"__main__\":\n    test_batch_generation_with_grammar()\n    test_batch_generation_with_schema()\n    test_batch_generation_jump_forward(False)\n    test_batch_generation_jump_forward(True)\n    test_async_engine(\"schema\", False, 1)\n    test_async_engine(\"schema\", True, 1)\n    test_async_engine(\"schema\", False, 8)\n    test_async_engine(\"schema\", True, 8)\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine_image.py",
    "content": "import json\nfrom pathlib import Path\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import data\nfrom mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine\n\n\ndef get_test_image(config) -> data.ImageData:\n    return data.ImageData.from_url(\"https://llava-vl.github.io/static/images/view.jpg\", config)\n\n\ndef test_engine_generate():\n    # Create engine\n    model = \"dist/llava-1.5-7b-hf-q4f16_1-MLC/params\"\n    model_lib = \"dist/llava-1.5-7b-hf-q4f16_1-MLC/llava-1.5-7b-hf-q4f16_1-MLC.so\"\n    engine = SyncMLCEngine(\n        model=model,\n        model_lib=model_lib,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n    max_tokens = 256\n\n    with open(Path(model) / \"mlc-chat-config.json\", \"r\", encoding=\"utf-8\") as file:\n        model_config = json.load(file)\n\n    prompts = [\n        [\n            data.TextData(\"USER: \"),\n            get_test_image(model_config),\n            data.TextData(\"\\nWhat does this image represent? ASSISTANT:\"),\n        ],\n        [\n            data.TextData(\"USER: \"),\n            get_test_image(model_config),\n            data.TextData(\"\\nIs there a dog in this image? ASSISTANT:\"),\n        ],\n        [data.TextData(\"USER: What is the meaning of life? ASSISTANT:\")],\n    ]\n\n    output_texts, _ = engine.generate(\n        prompts, GenerationConfig(max_tokens=max_tokens, stop_token_ids=[2])\n    )\n\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n\nif __name__ == \"__main__\":\n    test_engine_generate()\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine_mock.py",
    "content": "\"\"\"Mock tests for engine I/O conventions.\n\nMock tests can only check that the overall input/output\nprocessing options are passed through correctly.\n\"\"\"\n\nimport pytest\nimport tvm\n\nfrom mlc_llm.serve import MLCEngine\nfrom mlc_llm.testing import require_test_model\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\n# NOTE: only the tokenizer files are needed in the model folder.\n# The mock test launches quickly, so it can run as part of the unittests.\n@require_test_model(\"Llama-3-8B-Instruct-q4f16_1-MLC\")\ndef test_completion_api(model: str):\n    engine = MLCEngine(model, tvm.cpu(), model_lib=\"mock://echo\")\n    param_dict = {\n        \"top_p\": 0.6,\n        \"temperature\": 0.9,\n        \"frequency_penalty\": 0.1,\n        \"presence_penalty\": 0.1,\n        \"n\": 2,\n    }\n    response = engine.chat.completions.create(  # type: ignore\n        messages=[{\"role\": \"user\", \"content\": \"hello\"}],\n        **param_dict,\n    )\n    # the echo mock echoes the generation config back in usage.extra\n    for k, v in param_dict.items():\n        assert response.usage.extra[k] == v\n\n\nif __name__ == \"__main__\":\n    test_completion_api()\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine_prefix_cache.py",
    "content": "from mlc_llm.protocol.debug_protocol import DebugConfig\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine\nfrom mlc_llm.testing import require_test_model\n\nprompts = [\n    \"The meaning of life is\",\n    \"According to the history of Pittsburgh,\",\n    \"I have a three-day Seattle travel plan. On the first day,\",\n    \"Undoubtedly, Alaska is one of the most beautiful places on Earth,\",\n    \"Explain difference between Lambda calculus and Turing machine is\",\n    \"To assemble a desktop computer, we need the necessary components of\",\n    \"Vitamin D is important to human beings, because\",\n    \"Refer to history, the milk tea is originated from\",\n    \"In the southernmost place in United States,\",\n    \"AlphaGo has the capabilities of\",\n]\n\n\ndef test_engine_system_prompt(engine):\n    system_prompt = \"This is a system prompt\"\n    system_prompt_tokens = len(engine.tokenizer.encode(system_prompt))\n    max_tokens = 8\n    _, _ = engine.generate(\n        system_prompt,\n        GenerationConfig(\n            temperature=0,\n            max_tokens=max_tokens,\n            debug_config=DebugConfig(pinned_system_prompt=True),\n        ),\n    )\n    metrics = engine.metrics()\n    assert metrics[\"prefill_tokens_sum\"] == system_prompt_tokens\n    sum_prefill_tokens = system_prompt_tokens\n\n    input_token_lens = [len(engine.tokenizer.encode(prompt)) for prompt in prompts]\n\n    generation_config = GenerationConfig(temperature=0, max_tokens=max_tokens)\n    _, _ = engine.generate(prompts, generation_config)\n    metrics = engine.metrics()\n    assert metrics[\"prefill_tokens_sum\"] == sum_prefill_tokens + sum(input_token_lens)\n    sum_prefill_tokens = metrics[\"prefill_tokens_sum\"]\n\n    _, _ = engine.generate(system_prompt + \" and why ?\", generation_config)\n    metrics = engine.metrics()\n    # system prompt is reused entirely\n    assert 
metrics[\"prefill_tokens_sum\"] == sum_prefill_tokens + 3\n    sum_prefill_tokens = metrics[\"prefill_tokens_sum\"]\n\n    _, _ = engine.generate(prompts[:4], generation_config)\n    metrics = engine.metrics()\n    # first 4 prompts are removed and need to prefill again\n    assert metrics[\"prefill_tokens_sum\"] == sum_prefill_tokens + sum(input_token_lens[:4])\n\n\ndef test_engine_multi_round(engine):\n    num_requests = 10\n    max_tokens = 8\n    generation_config = GenerationConfig(temperature=0, max_tokens=max_tokens)\n    input_token_lens = [len(engine.tokenizer.encode(prompt)) for prompt in prompts[:num_requests]]\n\n    output_texts, _ = engine.generate(prompts[:num_requests], generation_config)\n    metrics = engine.metrics()\n    assert metrics[\"prefill_tokens_sum\"] == sum(input_token_lens)\n    sum_prefill_tokens = metrics[\"prefill_tokens_sum\"]\n    concat_prompt = []\n    for i, output in enumerate(output_texts):\n        concat_prompt.append(prompts[i] + \" \" + output[0] + \" ?\")\n    output_texts, _ = engine.generate(concat_prompt[:num_requests], generation_config)\n    metrics = engine.metrics()\n    assert metrics[\"prefill_tokens_sum\"] == sum_prefill_tokens + 2 * num_requests\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_basic_engine_system_prompt(model: str):\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"local\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            prefix_cache_max_num_recycling_seqs=5,\n        ),\n    )\n    test_engine_system_prompt(engine)\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_basic_engine_multi_round(model: str):\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n    test_engine_multi_round(engine)\n\n\n@require_test_model(\n    
\"Llama-2-7b-chat-hf-q0f16-MLC\",\n    \"Llama-2-7b-chat-hf-q4f16_1-MLC\",\n)\ndef test_engine_spec_multi_round(model: str, small_model: str):\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[small_model],\n            speculative_mode=\"small_draft\",\n        ),\n    )\n\n    test_engine_multi_round(engine)\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_eagle_multi_round(model: str):\n    # Create engine\n    small_model = \"dist/Eagle-llama2-7b-chat-q0f16-MLC\"\n    small_model_lib = \"dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so\"\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[(small_model, small_model_lib)],\n            speculative_mode=\"eagle\",\n            max_num_sequence=80,\n        ),\n    )\n\n    test_engine_multi_round(engine)\n\n\nif __name__ == \"__main__\":\n    test_basic_engine_system_prompt()\n    test_basic_engine_multi_round()\n    test_engine_spec_multi_round()\n    test_engine_eagle_multi_round()\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine_rnn.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable\nfrom typing import List\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import EngineConfig, MLCEngine\n\nprompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous of? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where is milk tea originated from? Please elaborate in detail.\",\n    \"Where is the southernmost place in United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.\",\n]\n\n\ndef test_engine_generate() -> None:\n    engine = MLCEngine(\n        model=\"dist/rwkv-6-world-1b6-q0f16-MLC\",\n        model_lib=\"dist/rwkv-6-world-1b6-q0f16-MLC/rwkv-6-world-1b6-q0f16-MLC-cuda.so\",\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_num_sequence=8,\n            max_history_size=1,\n        ),\n    )\n\n    num_requests = 10\n    max_tokens = 256\n    generation_cfg = GenerationConfig(max_tokens=max_tokens, n=7)\n\n    output_texts: List[List[str]] = [\n        [\"\" for _ in range(generation_cfg.n)] for _ in range(num_requests)\n    ]\n    for rid in range(num_requests):\n        print(f\"generating for request {rid}\")\n        for delta_outputs in engine._generate(prompts[rid], generation_cfg, request_id=str(rid)):\n            assert len(delta_outputs) == generation_cfg.n\n            for i, delta_output in enumerate(delta_outputs):\n                output_texts[rid][i] += delta_output.delta_text\n\n    # Print output.\n    print(\"All finished\")\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n    engine.terminate()\n    del engine\n\n\nif __name__ == \"__main__\":\n    test_engine_generate()\n"
  },
  {
    "path": "tests/python/serve/test_serve_engine_spec.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals\nfrom typing import Callable, List, Optional\n\nimport numpy as np\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import Request, RequestStreamOutput, data\nfrom mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine\nfrom mlc_llm.testing import require_test_model\n\nprompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous of? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where is milk tea originated from? Please elaborate in detail.\",\n    \"Where is the southernmost place in United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.\",\n]\n\n\ndef create_requests(\n    num_requests: int,\n    stop_token_id: Optional[int] = None,\n    temperature: float = 0.8,\n    repetition_penalty: float = 1.0,\n    max_tokens_low: int = 256,\n    max_tokens_high: int = 257,\n) -> List[Request]:\n    assert num_requests >= 0 and num_requests <= len(prompts)\n\n    stop_token_ids = [stop_token_id] if stop_token_id is not None else []\n    requests = []\n    for req_id, prompt in zip(range(num_requests), prompts):\n        max_tokens = np.random.randint(max_tokens_low, max_tokens_high)\n        requests.append(\n            Request(\n                request_id=str(req_id),\n                inputs=data.TextData(prompt),\n                generation_config=GenerationConfig(\n                    temperature=temperature,\n                    repetition_penalty=repetition_penalty,\n                    max_tokens=max_tokens,\n                    stop_token_ids=stop_token_ids,\n                ),\n            )\n        )\n    return requests\n\n\n@require_test_model(\n    \"Llama-2-7b-chat-hf-q0f16-MLC\",\n    \"Llama-2-7b-chat-hf-q4f16_1-MLC\",\n)\ndef test_engine_basic(model: str, small_model: str):\n    \"\"\"Test engine **without continuous batching**.\n\n    - Add all requests to the engine altogether in the beginning.\n    - All requests have the same max_tokens. This means all requests\n    will end together.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + max_tokens - 1). 
Then check the output of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations).\n    num_requests = len(prompts)  # [4, 8, 10]\n    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.0  # [1.0, 1.01]\n    max_tokens: int = 256  # [32, 128, 256]\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n\n    # Define the callback function for request generation results\n    def fcallback(delta_outputs: List[RequestStreamOutput]):\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            assert len(stream_outputs) == 1\n            outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[small_model],\n            speculative_mode=\"small_draft\",\n        ),\n        request_stream_callback=fcallback,\n    )\n\n    # Create requests\n    requests = create_requests(\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run steps\n    for step in range(num_steps):\n        engine.step()\n\n    for req_id, output in enumerate(outputs):\n        print(f\"Prompt {req_id}: {requests[req_id].inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_eagle_basic(model: str):\n    \"\"\"Test engine **without continuous batching**.\n\n    - Add all requests to the engine altogether in the 
beginning.\n    - All requests have the same max_tokens. This means all requests\n    will end together.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + max_tokens - 1). Then check the output of each request.\n    - Use Eagle model as speculative model\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations).\n    num_requests = len(prompts)  # [4, 8, 10]\n    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.0  # [1.0, 1.01]\n    max_tokens: int = 256  # [32, 128, 256]\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n\n    # Define the callback function for request generation results\n    def fcallback(delta_outputs: List[RequestStreamOutput]):\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            assert len(stream_outputs) == 1\n            outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n\n    # Create engine\n    small_model = \"dist/Eagle-llama2-7b-chat-q0f16-MLC\"\n    small_model_lib = \"dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so\"\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[(small_model, small_model_lib)],\n            speculative_mode=\"eagle\",\n            spec_draft_length=2,\n        ),\n        request_stream_callback=fcallback,\n    )\n\n    # Create requests\n    requests = create_requests(\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run 
steps\n    for step in range(num_steps):\n        engine.step()\n\n    for req_id, output in enumerate(outputs):\n        print(f\"Prompt {req_id}: {requests[req_id].inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n\n\n@require_test_model(\n    \"Llama-2-7b-chat-hf-q0f16-MLC\",\n    \"Llama-2-7b-chat-hf-q4f16_1-MLC\",\n)\ndef test_engine_continuous_batching_1(model: str, small_model: str):\n    \"\"\"Test engine **with continuous batching**.\n\n    - Add all requests to the engine altogether in the beginning.\n    - All requests have a random maximum generation length. So each\n    request keeps generating until reaching the maximum length.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + the maximum max_tokens - 1). Then check the output\n    of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations)\n    num_requests = len(prompts)  # [4, 8, 10]\n    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.00  # [1.0, 1.01]\n    max_tokens_low = 128\n    max_tokens_high = 384\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n    finish_time: List[Optional[int]] = [None] * num_requests\n\n    # Define the callback class for request generation results\n    class CallbackTimer:\n        timer: int = -1\n\n        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:\n            def fcallback(delta_outputs: List[RequestStreamOutput]):\n                for delta_output in delta_outputs:\n                    request_id, stream_outputs = delta_output.unpack()\n                    assert len(stream_outputs) == 1\n                    if stream_outputs[0].finish_reason is not None:\n                        print(f\"Request {request_id} finished at step {self.timer}.\")\n                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n      
              finish_time[int(request_id)] = self.timer\n\n            return fcallback\n\n        def step(self) -> None:\n            self.timer += 1\n\n    # Create engine\n    timer = CallbackTimer()\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[small_model],\n            speculative_mode=\"small_draft\",\n        ),\n        request_stream_callback=timer.callback_getter(),\n    )\n\n    # Create requests\n    requests = create_requests(\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens_low,\n        max_tokens_high=max_tokens_high,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1\n    # Run steps\n    for step in range(num_steps):\n        timer.step()\n        assert timer.timer == step\n        engine.step()\n\n    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):\n        print(f\"Prompt {req_id}: {request.inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n        # assert fin_time == request.generation_config.max_tokens - 1\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_engine_eagle_continuous_batching_1(model: str):\n    \"\"\"Test engine **with continuous batching**.\n\n    - Add all requests to the engine altogether in the beginning.\n    - All requests have a random maximum generation length. So each\n    request keeps generating until reaching the maximum length.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + the maximum max_tokens - 1). 
Then check the output\n    of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations)\n    num_requests = len(prompts)  # [4, 8, 10]\n    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.00  # [1.0, 1.01]\n    max_tokens_low = 128\n    max_tokens_high = 384\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n    finish_time: List[Optional[int]] = [None] * num_requests\n\n    # Define the callback class for request generation results\n    class CallbackTimer:\n        timer: int = -1\n\n        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:\n            def fcallback(delta_outputs: List[RequestStreamOutput]):\n                for delta_output in delta_outputs:\n                    request_id, stream_outputs = delta_output.unpack()\n                    assert len(stream_outputs) == 1\n                    if stream_outputs[0].finish_reason is not None:\n                        print(f\"Request {request_id} finished at step {self.timer}.\")\n                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n                    finish_time[int(request_id)] = self.timer\n\n            return fcallback\n\n        def step(self) -> None:\n            self.timer += 1\n\n    # Create engine\n    small_model = \"dist/Eagle-llama2-7b-chat-q4f16_1-MLC\"\n    small_model_lib = (\n        \"dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so\"\n    )\n    timer = CallbackTimer()\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[(small_model, small_model_lib)],\n            speculative_mode=\"eagle\",\n        ),\n        request_stream_callback=timer.callback_getter(),\n    )\n\n    # Create requests\n    requests = create_requests(\n        
num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens_low,\n        max_tokens_high=max_tokens_high,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1\n    # Run steps\n    for step in range(num_steps):\n        timer.step()\n        assert timer.timer == step\n        engine.step()\n\n    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):\n        print(f\"Prompt {req_id}: {request.inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n        # assert fin_time == request.generation_config.max_tokens - 1\n\n\ndef compare_output_text(output_text1, output_text2):\n    if isinstance(output_text1, list) and isinstance(output_text2, list):\n        for item1, item2 in zip(output_text1, output_text2):\n            if not compare_output_text(item1, item2):\n                return False\n    elif output_text1 != output_text2:\n        print(output_text1)\n        print(output_text2)\n        return False\n    return True\n\n\n@require_test_model(\n    \"Llama-2-7b-chat-hf-q0f16-MLC\",\n    \"Llama-2-7b-chat-hf-q4f16_1-MLC\",\n)\ndef test_engine_generate(model: str, small_model: str, compare_precision=False):\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[small_model],\n            speculative_mode=\"small_draft\",\n        ),\n    )\n\n    num_requests = 10\n    max_tokens = 256\n\n    # Generate output.\n    if compare_precision:\n        print(\"compare precision\")\n        generation_config = GenerationConfig(\n            temperature=0.0, top_p=0, max_tokens=1024, stop_token_ids=[2], n=1\n        
)\n        engine_single_model = SyncMLCEngine(\n            model=model,\n            mode=\"server\",\n            engine_config=EngineConfig(\n                max_total_sequence_length=4096,\n            ),\n        )\n        output_texts_single_model, _ = engine_single_model.generate(\n            prompts[:num_requests], generation_config\n        )\n        for req_id, outputs in enumerate(output_texts_single_model):\n            print(f\"Prompt {req_id}: {prompts[req_id]}\")\n            if len(outputs) == 1:\n                print(f\"Output {req_id}:{outputs[0]}\\n\")\n            else:\n                for i, output in enumerate(outputs):\n                    print(f\"Output {req_id}({i}):{output}\\n\")\n        # TODO: Add pytorch precision\n    else:\n        generation_config = GenerationConfig(max_tokens=max_tokens, n=3)\n    output_texts, _ = engine.generate(prompts[:num_requests], generation_config)\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n    if compare_precision:\n        precision_flag = compare_output_text(output_texts, output_texts_single_model)\n        if precision_flag:\n            print(\"Accuracy verification succeeded\\n\")\n        else:\n            print(\"Accuracy verification failed\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_eagle_generate(model: str):\n    # Create engine\n    small_model = \"dist/Eagle-llama2-7b-chat-q4f16_1-MLC\"\n    small_model_lib = (\n        \"dist/Eagle-llama2-7b-chat-q4f16_1-MLC/Eagle-llama2-7b-chat-q4f16_1-MLC-cuda.so\"\n    )\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n      
      additional_models=[(small_model, small_model_lib)],\n            speculative_mode=\"eagle\",\n        ),\n    )\n\n    num_requests = 10\n    max_tokens = 256\n\n    # Generate output.\n    output_texts, _ = engine.generate(\n        prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=3)\n    )\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n\n@require_test_model(\"Llama-2-13b-chat-hf-q4f16_1-MLC\")\ndef test_engine_efficiency(model: str):\n    \"\"\"Test engine speculative decoding efficiency.\"\"\"\n\n    # Hyperparameters for tests (you can try different combinations).\n    num_requests = 1  # [4, 8, 10]\n    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.0  # [1.0, 1.01]\n    max_tokens: int = 512\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n\n    # Define the callback function for request generation results\n    def fcallback(delta_outputs: List[RequestStreamOutput]):\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            assert len(stream_outputs) == 1\n            outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n        request_stream_callback=fcallback,\n    )\n\n    # Create requests\n    requests = create_requests(\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests 
to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run steps\n    for step in range(num_steps):\n        engine.step()\n\n    for eg, name in zip([engine], [\"Normal Decoding\"]):\n        metrics = eg.metrics()\n        print(\"engine name:\", name)\n        if name == \"Speculative Decoding\":\n            print(\"spec decode metrics:\", metrics[\"spec_decode\"])\n        print(\"engine total decode time:\", metrics[\"engine_decode_time_sum\"])\n        print()\n\n\n@require_test_model(\n    \"Llama-2-13b-chat-hf-q4f16_1-MLC\",\n    \"Llama-2-7b-chat-hf-q4f16_1-MLC\",\n)\ndef test_engine_spec_efficiency(model: str, small_model: str):\n    \"\"\"Test engine speculative decoding efficiency.\"\"\"\n\n    # Hyperparameters for tests (you can try different combinations).\n    num_requests = 1  # [4, 8, 10]\n    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.0  # [1.0, 1.01]\n    max_tokens: int = 512\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n\n    # Define the callback function for request generation results\n    def fcallback(delta_outputs: List[RequestStreamOutput]):\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            assert len(stream_outputs) == 1\n            outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n\n    # Create engine\n    spec_engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[small_model],\n            spec_draft_length=6,\n            speculative_mode=\"small_draft\",\n        ),\n        request_stream_callback=fcallback,\n    )\n\n    # Create requests\n    requests = create_requests(\n        num_requests,\n        temperature=temperature,\n        
repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        spec_engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run steps\n    for step in range(num_steps):\n        spec_engine.step()\n\n    for eg, name in zip([spec_engine], [\"Speculative Decoding\"]):\n        metrics = eg.metrics()\n        print(\"engine name:\", name)\n        if name == \"Speculative Decoding\":\n            print(\"total draft tokens:\", metrics[\"sum_num_draft_tokens\"])\n            print(\"total accepted tokens:\", metrics[\"sum_num_accepted_tokens\"])\n            print(\n                \"Accept rate:\",\n                metrics[\"sum_num_accepted_tokens\"] / (1e-10 + metrics[\"sum_num_draft_tokens\"]),\n            )\n        print(\"engine total decode time:\", metrics[\"engine_decode_time_sum\"])\n        print()\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_engine_eagle_spec_efficiency(model: str):\n    \"\"\"Test engine speculative decoding efficiency.\"\"\"\n\n    # Hyperparameters for tests (you can try different combinations).\n    num_requests = 1  # [4, 8, 10]\n    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.0  # [1.0, 1.01]\n    max_tokens: int = 512\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n\n    # Define the callback function for request generation results\n    def fcallback(delta_outputs: List[RequestStreamOutput]):\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            assert len(stream_outputs) == 1\n            outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n\n    # Create engine\n    small_model = \"dist/Eagle-llama2-7b-chat-q0f16-MLC\"\n    small_model_lib = 
\"dist/Eagle-llama2-7b-chat-q0f16-MLC/Eagle-llama2-7b-chat-q0f16-MLC-cuda.so\"\n    spec_engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(\n            max_total_sequence_length=4096,\n            additional_models=[(small_model, small_model_lib)],\n            spec_draft_length=6,\n            speculative_mode=\"eagle\",\n        ),\n        request_stream_callback=fcallback,\n    )\n\n    # Create requests\n    requests = create_requests(\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        spec_engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run steps\n    for step in range(num_steps):\n        spec_engine.step()\n\n    for eg, name in zip([spec_engine], [\"Speculative Decoding\"]):\n        metrics = eg.metrics()\n        print(\"engine name:\", name)\n        if name == \"Speculative Decoding\":\n            print(\"spec decode:\", metrics[\"spec_decode\"])\n        print(\"engine total decode time:\", metrics[\"engine_decode_time_sum\"])\n        print()\n\n\nif __name__ == \"__main__\":\n    test_engine_basic()\n    test_engine_eagle_basic()\n    test_engine_continuous_batching_1()\n    test_engine_eagle_continuous_batching_1()\n    test_engine_generate(compare_precision=True)\n    test_engine_eagle_generate()\n    test_engine_efficiency()\n    test_engine_spec_efficiency()\n    test_engine_eagle_spec_efficiency()\n"
  },
  {
    "path": "tests/python/serve/test_serve_sync_engine.py",
    "content": "# pylint: disable=chained-comparison,line-too-long,missing-docstring,\n# pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable\nfrom typing import Callable, List, Optional\n\nimport numpy as np\n\nfrom mlc_llm.protocol.generation_config import GenerationConfig\nfrom mlc_llm.serve import Request, RequestStreamOutput, data\nfrom mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine\nfrom mlc_llm.testing import require_test_model\n\nprompts = [\n    \"What is the meaning of life?\",\n    \"Introduce the history of Pittsburgh to me. Please elaborate in detail.\",\n    \"Write a three-day Seattle travel plan. Please elaborate in detail.\",\n    \"What is Alaska famous of? Please elaborate in detail.\",\n    \"What is the difference between Lambda calculus and Turing machine? Please elaborate in detail.\",\n    \"What are the necessary components to assemble a desktop computer? Please elaborate in detail.\",\n    \"Why is Vitamin D important to human beings? Please elaborate in detail.\",\n    \"Where is milk tea originated from? Please elaborate in detail.\",\n    \"Where is the southernmost place in United States? Please elaborate in detail.\",\n    \"Do you know AlphaGo? What capabilities does it have, and what achievements has it got? 
Please elaborate in detail.\",\n]\n\n\ndef create_requests(\n    engine: SyncMLCEngine,\n    num_requests: int,\n    stop_token_id: Optional[int] = None,\n    temperature: float = 0.8,\n    repetition_penalty: float = 1.0,\n    max_tokens_low: int = 256,\n    max_tokens_high: int = 257,\n) -> List[Request]:\n    assert num_requests >= 0 and num_requests <= len(prompts)\n\n    stop_token_ids = [stop_token_id] if stop_token_id is not None else []\n    requests = []\n    for req_id, prompt in zip(range(num_requests), prompts):\n        max_tokens = np.random.randint(max_tokens_low, max_tokens_high)\n        requests.append(\n            engine.create_request(\n                request_id=str(req_id),\n                inputs=data.TextData(prompt),\n                generation_config=GenerationConfig(\n                    temperature=temperature,\n                    repetition_penalty=repetition_penalty,\n                    max_tokens=max_tokens,\n                    stop_token_ids=stop_token_ids,\n                ),\n            )\n        )\n    return requests\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_basic(model: str):\n    \"\"\"Test engine **without continuous batching**.\n\n    - Add all requests to the engine altogether in the beginning.\n    - All requests have the same max_tokens. This means all requests\n    will end together.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + max_tokens - 1). 
Then check the output of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations).\n    num_requests = 10  # [4, 8, 10]\n    temperature = 0.9  # [0, 0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.0  # [1.0, 1.01]\n    max_tokens: int = 256  # [32, 128, 256]\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n\n    # Define the callback function for request generation results\n    def fcallback(delta_outputs: List[RequestStreamOutput]):\n        for delta_output in delta_outputs:\n            request_id, stream_outputs = delta_output.unpack()\n            assert len(stream_outputs) == 1\n            outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        request_stream_callback=fcallback,\n    )\n\n    # Create requests\n    requests = create_requests(\n        engine,\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run steps\n    for step in range(num_steps):\n        engine.step()\n\n    for req_id, output in enumerate(outputs):\n        print(f\"Prompt {req_id}: {requests[req_id].inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_continuous_batching_1(model: str):\n    \"\"\"Test engine **with continuous batching**.\n\n    - Add all requests to the engine altogether in the beginning.\n    - All requests have a random maximum generation length. 
So each\n    request keeps generating until reaching the maximum length.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + the maximum max_tokens - 1). Then check the output\n    of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations)\n    num_requests = 10  # [4, 8, 10]\n    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.00  # [1.0, 1.01]\n    max_tokens_low = 128\n    max_tokens_high = 384\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n    finish_time: List[Optional[int]] = [None] * num_requests\n\n    # Define the callback class for request generation results\n    class CallbackTimer:\n        timer: int = -1\n\n        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:\n            def fcallback(delta_outputs: List[RequestStreamOutput]):\n                for delta_output in delta_outputs:\n                    request_id, stream_outputs = delta_output.unpack()\n                    assert len(stream_outputs) == 1\n                    if stream_outputs[0].finish_reason is not None:\n                        print(f\"Request {request_id} finished at step {self.timer}.\")\n                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n                    finish_time[int(request_id)] = self.timer\n\n            return fcallback\n\n        def step(self) -> None:\n            self.timer += 1\n\n    # Create engine\n    timer = CallbackTimer()\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        request_stream_callback=timer.callback_getter(),\n    )\n\n    # Create requests\n    requests = create_requests(\n        engine,\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens_low,\n        max_tokens_high=max_tokens_high,\n    )\n\n    # 
Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max(request.generation_config.max_tokens for request in requests) - 1\n    # Run steps\n    for step in range(num_steps):\n        timer.step()\n        assert timer.timer == step\n        engine.step()\n\n    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):\n        print(f\"Prompt {req_id}: {request.inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n        assert (\n            fin_time == request.generation_config.max_tokens - 1\n        ), f\"finish time = {fin_time}, max tokens = {request.generation_config.max_tokens - 1}\"\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_continuous_batching_2(model: str):\n    \"\"\"Test engine **with continuous batching**.\n\n    - Add all requests to the engine altogether in the beginning.\n    - All requests have the stop token. So each request keeps generating\n    until having the stop token or reaching the maximum length.\n    - Engine keeps running `step` for estimated number of steps (number of\n    requests + the maximum max_tokens - 1). 
Then check the output\n    of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations)\n    num_requests = 10  # [4, 8, 10]\n    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.00  # [1.0, 1.01]\n    stop_token_id = 2\n    max_tokens = 512\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n    finish_time: List[Optional[int]] = [None] * num_requests\n\n    # Define the callback class for request generation results\n    class CallbackTimer:\n        timer: int = -1\n\n        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:\n            def fcallback(delta_outputs: List[RequestStreamOutput]):\n                for delta_output in delta_outputs:\n                    request_id, stream_outputs = delta_output.unpack()\n                    assert len(stream_outputs) == 1\n                    if stream_outputs[0].finish_reason is not None:\n                        print(f\"Request {request_id} finished at step {self.timer}.\")\n                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n                    finish_time[int(request_id)] = self.timer\n\n            return fcallback\n\n        def step(self) -> None:\n            self.timer += 1\n\n    # Create engine\n    timer = CallbackTimer()\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        request_stream_callback=timer.callback_getter(),\n    )\n\n    # Create requests\n    requests = create_requests(\n        engine,\n        num_requests,\n        stop_token_id=stop_token_id,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine\n    for request in requests:\n        engine.add_request(request)\n\n    num_steps = num_requests + max_tokens - 1\n    # Run steps\n    for 
step in range(num_steps):\n        timer.step()\n        assert timer.timer == step\n        engine.step()\n\n    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):\n        print(f\"Prompt {req_id}: {request.inputs[0]}\")\n        if fin_time < num_requests + max_tokens - 2:\n            print(f\"Request {req_id} ends early on the stop token\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_continuous_batching_3(model: str):\n    \"\"\"Test engine **with continuous batching**.\n\n    - Add requests randomly between time [0, 200).\n    - All requests have the stop token and a random maximum generation length.\n    So each request keeps generating until it emits the stop token or\n    reaches the maximum length.\n    - Engine keeps running `step` until all requests finish.\n    Then check the output of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations)\n    num_requests = 10  # [4, 8, 10]\n    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.00  # [1.0, 1.01]\n    stop_token_id = 2\n    max_tokens_low = 64\n    max_tokens_high = 192\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n    finish_time: List[Optional[int]] = [None] * num_requests\n\n    # Define the callback class for request generation results\n    class CallbackTimer:\n        timer: int = -1\n        finished_requests: int = 0\n\n        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:\n            def fcallback(delta_outputs: List[RequestStreamOutput]):\n                for delta_output in delta_outputs:\n                    request_id, stream_outputs = delta_output.unpack()\n                    assert len(stream_outputs) == 1\n                    if stream_outputs[0].finish_reason is not None:\n                        print(f\"Request {request_id} finished 
at step {self.timer}.\")\n                        self.finished_requests += 1\n                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n                    finish_time[int(request_id)] = self.timer\n\n            return fcallback\n\n        def step(self) -> None:\n            self.timer += 1\n\n        def all_finished(self) -> bool:\n            return self.finished_requests == num_requests\n\n    # Create engine\n    timer = CallbackTimer()\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        request_stream_callback=timer.callback_getter(),\n    )\n\n    # Create requests\n    requests = create_requests(\n        engine,\n        num_requests,\n        stop_token_id=stop_token_id,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens_low,\n        max_tokens_high=max_tokens_high,\n    )\n\n    # Assign the time to add requests to engine\n    request_add_time = [np.random.randint(0, 200) for _ in range(num_requests)]\n\n    # Run steps\n    while not timer.all_finished():\n        timer.step()\n\n        # Add requests to engine\n        for req_id, add_time in enumerate(request_add_time):\n            if add_time == timer.timer:\n                print(f\"add request {req_id} at step {timer.timer}\")\n                engine.add_request(requests[req_id])\n\n        engine.step()\n\n    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):\n        print(f\"Prompt {req_id}: {request.inputs[0]}\")\n        print(f\"Finish time: {fin_time}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_generate(model: str):\n    # Create engine\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        engine_config=EngineConfig(max_total_sequence_length=4096),\n    )\n\n    num_requests = 10\n    
max_tokens = 256\n\n    # Generate output.\n    output_texts, _ = engine.generate(\n        prompts[:num_requests], GenerationConfig(max_tokens=max_tokens, n=7)\n    )\n    for req_id, outputs in enumerate(output_texts):\n        print(f\"Prompt {req_id}: {prompts[req_id]}\")\n        if len(outputs) == 1:\n            print(f\"Output {req_id}:{outputs[0]}\\n\")\n        else:\n            for i, output in enumerate(outputs):\n                print(f\"Output {req_id}({i}):{output}\\n\")\n\n\n@require_test_model(\"Llama-2-7b-chat-hf-q0f16-MLC\")\ndef test_engine_hybrid_prefill(model: str):\n    \"\"\"Test engine **with hybrid prefill**.\n\n    - Add each single request step by step.\n    - All requests have the same generation length. But due to hybrid prefill,\n    the earlier request will decode with later request prefill, in single step.\n    So each request lasts the same steps, and stops generation step by step as well.\n    - Engine keeps running `step` for the generation length, to finish the last request.\n    Then check the output of each request.\n    \"\"\"\n\n    # Hyperparameters for tests (you can try different combinations)\n    num_requests = 10  # [4, 8, 10]\n    temperature = 0.9  # [0.8, 0.9, 1.0, 1.1]\n    repetition_penalty = 1.00  # [1.0, 1.01]\n    max_tokens = 15\n    np.random.seed(0)\n\n    # Output list\n    outputs: List[List[int]] = [[] for _ in range(num_requests)]\n    finish_time: List[Optional[int]] = [None] * num_requests\n\n    # Define the callback class for request generation results\n    class CallbackTimer:\n        timer: int = -1\n\n        def callback_getter(self) -> Callable[[List[RequestStreamOutput]], None]:\n            def fcallback(delta_outputs: List[RequestStreamOutput]):\n                for delta_output in delta_outputs:\n                    request_id, stream_outputs = delta_output.unpack()\n                    assert len(stream_outputs) == 1\n                    if stream_outputs[0].finish_reason is not None:\n  
                      print(f\"Request {request_id} finished at step {self.timer}.\")\n                    outputs[int(request_id)] += stream_outputs[0].delta_token_ids\n                    finish_time[int(request_id)] = self.timer\n\n            return fcallback\n\n        def step(self) -> None:\n            self.timer += 1\n\n    # Create engine\n    timer = CallbackTimer()\n    engine = SyncMLCEngine(\n        model=model,\n        mode=\"server\",\n        request_stream_callback=timer.callback_getter(),\n        engine_config=EngineConfig(prefill_mode=\"hybrid\"),\n    )\n\n    # Create requests\n    requests = create_requests(\n        engine,\n        num_requests,\n        temperature=temperature,\n        repetition_penalty=repetition_penalty,\n        max_tokens_low=max_tokens,\n        max_tokens_high=max_tokens + 1,\n    )\n\n    # Add all requests to engine step by step\n    for step, request in enumerate(requests):\n        engine.add_request(request)\n        timer.step()\n        assert timer.timer == step\n        engine.step()\n\n    # Run steps\n    for step in range(max_tokens):\n        timer.step()\n        assert timer.timer == step + num_requests\n        engine.step()\n\n    for req_id, (request, output, fin_time) in enumerate(zip(requests, outputs, finish_time)):\n        print(f\"Prompt {req_id}: {request.inputs[0]}\")\n        print(f\"Output {req_id}:{engine.tokenizer.decode(output)}\\n\")\n        assert (\n            fin_time == req_id + request.generation_config.max_tokens - 1\n        ), f\"finish time = {fin_time}, max tokens = {req_id + request.generation_config.max_tokens - 1}\"\n\n\nif __name__ == \"__main__\":\n    test_engine_basic()\n    test_engine_continuous_batching_1()\n    test_engine_continuous_batching_2()\n    test_engine_continuous_batching_3()\n    test_engine_generate()\n    test_engine_hybrid_prefill()\n"
  },
  {
    "path": "tests/python/support/test_auto_config.py",
    "content": "# pylint: disable=missing-docstring\nimport json\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.auto_config import detect_config\n\nlogging.enable_logging()\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\ndef _create_json_file(json_path, data):\n    with open(json_path, \"w\", encoding=\"utf-8\") as i_f:\n        json.dump(data, i_f)\n\n\ndef test_detect_config():\n    with tempfile.TemporaryDirectory() as tmpdir:\n        base_path = Path(tmpdir)\n        config_json_path = base_path / \"config.json\"\n        _create_json_file(config_json_path, {})\n\n        assert detect_config(base_path) == config_json_path\n        assert detect_config(config_json_path) == config_json_path\n\n\ndef test_detect_config_fail():\n    with pytest.raises(ValueError):\n        detect_config(Path(\"do/not/exist\"))\n\n    with tempfile.TemporaryDirectory() as tmpdir:\n        base_path = Path(tmpdir)\n        with pytest.raises(ValueError):\n            assert detect_config(base_path)\n\n\nif __name__ == \"__main__\":\n    test_detect_config()\n    test_detect_config_fail()\n"
  },
  {
    "path": "tests/python/support/test_auto_weight.py",
    "content": "# pylint: disable=missing-docstring\nimport json\nimport os\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom mlc_llm.support import logging\nfrom mlc_llm.support.auto_weight import detect_weight\n\nlogging.enable_logging()\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\ndef _create_json_file(json_path, data):\n    with open(json_path, \"w\", encoding=\"utf-8\") as i_f:\n        json.dump(data, i_f)\n\n\n@pytest.mark.parametrize(\n    \"weight_format, index_filename, result\",\n    [\n        (\"huggingface-torch\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"),\n        (\n            \"huggingface-safetensor\",\n            \"model.safetensors.index.json\",\n            \"huggingface-safetensor\",\n        ),\n        (\"auto\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"),\n        (\"auto\", \"model.safetensors.index.json\", \"huggingface-safetensor\"),\n    ],\n)\ndef test_detect_weight(weight_format, index_filename, result):\n    with tempfile.TemporaryDirectory() as tmpdir:\n        base_path = Path(tmpdir)\n        if index_filename is not None:\n            weight_index_file = base_path / index_filename\n            _create_json_file(weight_index_file, {})\n        assert detect_weight(base_path, None, weight_format) == (\n            weight_index_file,\n            result,\n        )\n\n\n@pytest.mark.parametrize(\n    \"weight_format, index_filename, result\",\n    [\n        (\"huggingface-torch\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"),\n        (\n            \"huggingface-safetensor\",\n            \"model.safetensors.index.json\",\n            \"huggingface-safetensor\",\n        ),\n        (\"auto\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"),\n        (\"auto\", \"model.safetensors.index.json\", \"huggingface-safetensor\"),\n    ],\n)\ndef test_detect_weight_in_config_json(weight_format, index_filename, result):\n    with (\n        
tempfile.TemporaryDirectory() as config_dir,\n        tempfile.TemporaryDirectory() as weight_dir,\n    ):\n        config_path = Path(config_dir)\n        weight_path = Path(weight_dir)\n        config_json_path = config_path / \"config.json\"\n        _create_json_file(config_json_path, {\"weight_path\": weight_dir})\n        if index_filename is not None:\n            weight_index_file = weight_path / index_filename\n            _create_json_file(weight_index_file, {})\n\n        assert detect_weight(None, config_json_path, weight_format) == (\n            weight_index_file,\n            result,\n        )\n\n\n@pytest.mark.parametrize(\n    \"weight_format, index_filename, result\",\n    [\n        (\"huggingface-torch\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"),\n        (\n            \"huggingface-safetensor\",\n            \"model.safetensors.index.json\",\n            \"huggingface-safetensor\",\n        ),\n        (\"auto\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"),\n        (\"auto\", \"model.safetensors.index.json\", \"huggingface-safetensor\"),\n    ],\n)\ndef test_detect_weight_same_dir_config_json(weight_format, index_filename, result):\n    with tempfile.TemporaryDirectory() as tmpdir:\n        base_path = Path(tmpdir)\n        config_json_path = base_path / \"config.json\"\n        _create_json_file(config_json_path, {})\n        if index_filename is not None:\n            weight_index_file = Path(os.path.join(tmpdir, index_filename))\n            _create_json_file(weight_index_file, {})\n        assert detect_weight(None, config_json_path, weight_format) == (\n            weight_index_file,\n            result,\n        )\n\n\ndef test_find_weight_fail():\n    with tempfile.TemporaryDirectory() as tmpdir:\n        base_path = Path(tmpdir)\n        with pytest.raises(ValueError):\n            detect_weight(Path(\"do/not/exist\"), base_path, \"awq\")\n    with pytest.raises(AssertionError):\n        
detect_weight(None, Path(\"do/not/exist\"), \"awq\")\n\n\nif __name__ == \"__main__\":\n    test_detect_weight(\"huggingface-torch\", \"pytorch_model.bin.index.json\", \"huggingface-torch\")\n    test_detect_weight(\n        \"huggingface-safetensor\",\n        \"model.safetensors.index.json\",\n        \"huggingface-safetensor\",\n    )\n    test_detect_weight(\"auto\", \"pytorch_model.bin.index.json\", \"huggingface-torch\")\n    test_detect_weight(\"auto\", \"model.safetensors.index.json\", \"huggingface-safetensor\")\n    test_detect_weight_in_config_json(\n        \"huggingface-torch\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"\n    )\n    test_detect_weight_in_config_json(\n        \"huggingface-safetensor\",\n        \"model.safetensors.index.json\",\n        \"huggingface-safetensor\",\n    )\n    test_detect_weight_in_config_json(\"auto\", \"pytorch_model.bin.index.json\", \"huggingface-torch\")\n    test_detect_weight_in_config_json(\n        \"auto\", \"model.safetensors.index.json\", \"huggingface-safetensor\"\n    )\n    test_detect_weight_same_dir_config_json(\n        \"huggingface-torch\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"\n    )\n    test_detect_weight_same_dir_config_json(\n        \"huggingface-safetensor\",\n        \"model.safetensors.index.json\",\n        \"huggingface-safetensor\",\n    )\n    test_detect_weight_same_dir_config_json(\n        \"auto\", \"pytorch_model.bin.index.json\", \"huggingface-torch\"\n    )\n    test_detect_weight_same_dir_config_json(\n        \"auto\", \"model.safetensors.index.json\", \"huggingface-safetensor\"\n    )\n    test_find_weight_fail()\n"
  },
  {
    "path": "tests/python/support/test_cli_convert_weight.py",
    "content": "# pylint: disable=missing-docstring\nimport json\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom mlc_llm.cli import convert_weight as convert_weight_cli\n\npytestmark = [pytest.mark.unittest]\n\n\ndef test_convert_weight_cli_passes_lora_adapter(monkeypatch):\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        temp_path = Path(tmp_dir)\n        config_path = temp_path / \"config.json\"\n        source_dir = temp_path / \"source\"\n        source_dir.mkdir(parents=True, exist_ok=True)\n        source_index = source_dir / \"pytorch_model.bin.index.json\"\n        adapter_dir = temp_path / \"adapter\"\n        adapter_dir.mkdir(parents=True, exist_ok=True)\n        output_dir = temp_path / \"output\"\n\n        config_path.write_text(json.dumps({}), encoding=\"utf-8\")\n        source_index.write_text(json.dumps({\"weight_map\": {}}), encoding=\"utf-8\")\n\n        def _fake_detect_device(device):\n            return device\n\n        def _fake_detect_weight(_weight_path, _config_json_path, _weight_format):\n            return source_index, \"huggingface-torch\"\n\n        def _fake_detect_model_type(_model_type, _config):\n            return \"dummy\"\n\n        monkeypatch.setattr(convert_weight_cli, \"detect_config\", Path)\n        monkeypatch.setattr(convert_weight_cli, \"detect_device\", _fake_detect_device)\n        monkeypatch.setattr(convert_weight_cli, \"detect_weight\", _fake_detect_weight)\n        monkeypatch.setattr(convert_weight_cli, \"detect_model_type\", _fake_detect_model_type)\n        monkeypatch.setattr(convert_weight_cli, \"MODELS\", {\"dummy\": object()})\n        monkeypatch.setattr(convert_weight_cli, \"QUANTIZATION\", {\"q0f16\": object()})\n\n        call_args = {}\n\n        def _fake_convert_weight(**kwargs):\n            call_args.update(kwargs)\n\n        monkeypatch.setattr(convert_weight_cli, \"convert_weight\", _fake_convert_weight)\n\n        convert_weight_cli.main(\n            [\n       
         str(config_path),\n                \"--quantization\",\n                \"q0f16\",\n                \"--model-type\",\n                \"dummy\",\n                \"--source\",\n                str(source_dir),\n                \"--source-format\",\n                \"auto\",\n                \"--output\",\n                str(output_dir),\n                \"--lora-adapter\",\n                str(adapter_dir),\n            ]\n        )\n\n        assert call_args[\"lora_adapter\"] == adapter_dir\n        assert call_args[\"source\"] == source_index\n        assert call_args[\"source_format\"] == \"huggingface-torch\"\n"
  },
  {
    "path": "tests/python/support/test_convert_weight_lora_merge.py",
    "content": "# pylint: disable=missing-docstring,protected-access\nimport contextlib\nimport json\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom mlc_llm.interface import convert_weight as convert_weight_interface\n\npytestmark = [pytest.mark.unittest]\n\n\ndef test_resolve_base_model_dir():\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        temp_path = Path(tmp_dir)\n        model_dir = temp_path / \"model\"\n        model_dir.mkdir(parents=True, exist_ok=True)\n        source_file = model_dir / \"pytorch_model.bin.index.json\"\n        source_file.write_text(json.dumps({\"weight_map\": {}}), encoding=\"utf-8\")\n\n        assert convert_weight_interface._resolve_base_model_dir(model_dir) == model_dir\n        assert convert_weight_interface._resolve_base_model_dir(source_file) == model_dir\n\n\ndef test_convert_weight_with_lora_uses_merged_source(monkeypatch):\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        temp_path = Path(tmp_dir)\n        config_path = temp_path / \"config.json\"\n        config_path.write_text(json.dumps({}), encoding=\"utf-8\")\n\n        source_dir = temp_path / \"source\"\n        source_dir.mkdir(parents=True, exist_ok=True)\n        source_file = source_dir / \"pytorch_model.bin.index.json\"\n        source_file.write_text(json.dumps({\"weight_map\": {}}), encoding=\"utf-8\")\n\n        adapter_dir = temp_path / \"adapter\"\n        adapter_dir.mkdir(parents=True, exist_ok=True)\n\n        merged_dir = temp_path / \"merged\"\n        merged_dir.mkdir(parents=True, exist_ok=True)\n        merged_file = merged_dir / \"pytorch_model.bin\"\n        merged_file.write_bytes(b\"\")\n\n        captured = {}\n\n        @contextlib.contextmanager\n        def _fake_merge(base_source: Path, lora_adapter: Path):\n            captured[\"merge_base_source\"] = base_source\n            captured[\"merge_lora_adapter\"] = lora_adapter\n            yield merged_dir\n\n        def 
_fake_detect_weight(weight_path: Path, config_json_path: Path, weight_format: str):\n            captured[\"detect_weight_path\"] = weight_path\n            captured[\"detect_weight_config\"] = config_json_path\n            captured[\"detect_weight_format\"] = weight_format\n            return merged_file, \"huggingface-torch\"\n\n        def _fake_convert_args(args):\n            captured[\"converted_args\"] = args\n\n        monkeypatch.setattr(\n            convert_weight_interface, \"_merge_lora_adapter_with_base_model\", _fake_merge\n        )\n        monkeypatch.setattr(convert_weight_interface, \"detect_weight\", _fake_detect_weight)\n        monkeypatch.setattr(convert_weight_interface, \"_convert_args\", _fake_convert_args)\n        monkeypatch.setattr(convert_weight_interface.ConversionArgs, \"display\", lambda self: None)\n\n        convert_weight_interface.convert_weight(\n            config=config_path,\n            quantization=object(),\n            model=type(\"DummyModel\", (), {\"name\": \"dummy\"})(),\n            device=object(),\n            source=source_file,\n            source_format=\"huggingface-safetensor\",\n            output=temp_path / \"output\",\n            lora_adapter=adapter_dir,\n        )\n\n        converted_args = captured[\"converted_args\"]\n        assert captured[\"merge_base_source\"] == source_file\n        assert captured[\"merge_lora_adapter\"] == adapter_dir\n        assert captured[\"detect_weight_path\"] == merged_dir\n        assert captured[\"detect_weight_config\"] == config_path\n        assert captured[\"detect_weight_format\"] == \"auto\"\n        assert converted_args.source == merged_file\n        assert converted_args.source_format == \"huggingface-torch\"\n        assert converted_args.lora_adapter == adapter_dir\n\n\ndef test_convert_weight_with_lora_rejects_awq():\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        temp_path = Path(tmp_dir)\n        config_path = temp_path / 
\"config.json\"\n        config_path.write_text(json.dumps({}), encoding=\"utf-8\")\n        adapter_dir = temp_path / \"adapter\"\n        adapter_dir.mkdir(parents=True, exist_ok=True)\n\n        with pytest.raises(ValueError, match=\"only supports source formats\"):\n            convert_weight_interface.convert_weight(\n                config=config_path,\n                quantization=object(),\n                model=type(\"DummyModel\", (), {\"name\": \"dummy\"})(),\n                device=object(),\n                source=temp_path / \"source\",\n                source_format=\"awq\",\n                output=temp_path / \"output\",\n                lora_adapter=adapter_dir,\n            )\n"
  },
  {
    "path": "tests/python/tokenizers/test_streamer.py",
    "content": "\"\"\"Streamer tests in MLC LLM.\n\nPlease specify the local path to the llama2 tokenizer via an environment\nvariable before running this test.\nThe recommended way to run the tests is to use the following command:\n  MLC_LLAMA_TOKENIZER_PATH=\"path/to/llama/tokenizer\" \\\n  pytest -vv tests/python/tokenizers/test_streamer.py\n\nHere \"MLC_LLAMA_TOKENIZER_PATH\" can be chosen from\n- a llama2 weight directory (e.g., \"path/to/Llama-2-7b-chat-hf\"),\n- a sentencepiece llama2 tokenizer path\n  (e.g., \"path/to/Llama-2-7b-chat-hf/tokenizer.model\").\n\nTo directly run the Python file (i.e., not using pytest), you also need to\nspecify the tokenizer path via the environment variable.\n\"\"\"\n\n# pylint: disable=missing-function-docstring\nimport time\nfrom typing import List, Tuple\n\nimport pytest\n\nfrom mlc_llm.testing import require_test_tokenizers\nfrom mlc_llm.tokenizers import StopStrHandler, TextStreamer, Tokenizer\n\n# test category \"unittest\"\npytestmark = [pytest.mark.unittest]\n\n\n# fmt: off\npara_input_tokens = [18585, 29892, 1244, 29915, 29879, 263, 3273, 14880, 1048, 953, 29877, 2397,\n          29892, 988, 1269, 1734, 338, 5643, 491, 385, 953, 29877, 2397, 29901, 13, 13,\n          29950, 1032, 727, 29991, 29871, 243, 162, 148, 142, 306, 29915, 29885, 1244, 304,\n          1371, 1234, 738, 5155, 366, 505, 1048, 953, 29877, 2397, 29871, 243, 162, 167, 151,\n          29889, 7440, 366, 1073, 393, 953, 29877, 2397, 508, 367, 1304, 304, 27769, 23023,\n          1080, 322, 21737, 297, 263, 2090, 322, 1708, 1319, 982, 29973, 29871, 243, 162, 155,\n          135, 2688, 508, 884, 367, 1304, 304, 788, 263, 6023, 310, 2022, 2877, 304, 596, 7191,\n          322, 11803, 29889, 29871, 243, 162, 149, 152, 1126, 29892, 1258, 366, 1073, 393, 727,\n          526, 1584, 953, 29877, 2397, 8090, 322, 14188, 366, 508, 1708, 29973, 29871, 243, 162,\n          145, 177, 243, 162, 148, 131, 1105, 29892, 748, 14432, 322, 679, 907, 1230, 411, 953,\n   
       29877, 2397, 29991, 29871, 243, 162, 149, 168, 243, 162, 145, 171]\n\nDECODED_PARAGRAPH = (\n    \"Sure, here's a short paragraph about emoji, \"\n    \"where each word is followed by an emoji:\\n\\n\"\n    \"Hey there! 👋 I'm here to help answer any questions you have about emoji 🤔. \"\n    \"Did you know that emoji can be used to convey emotions and feelings in a \"\n    \"fun and playful way? 😄 \"\n    \"They can also be used to add a touch of personality to your messages and posts. 💕 \"\n    \"And, did you know that there are even emoji games and activities you can play? 🎮👀 \"\n    \"So, go ahead and get creative with emoji! 💥🎨\"\n)\n# fmt: on\n\n\n@require_test_tokenizers(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_text_streamer(llama_tokenizer_path: str):  # pylint: disable=redefined-outer-name\n    text_streamer = TextStreamer(Tokenizer(llama_tokenizer_path))\n    total_text = \"\"\n    for token in para_input_tokens:\n        total_text += text_streamer.put([token])\n    total_text += text_streamer.finish()\n\n    assert total_text == DECODED_PARAGRAPH\n\n\ndef stop_handler_process_tokens(\n    stop_handler: StopStrHandler, tokens: List[int], tokenizer: Tokenizer\n) -> str:\n    returned_tokens = []\n    for token in tokens:\n        returned_tokens += stop_handler.put(token)\n        if stop_handler.stop_triggered:\n            break\n\n    if not stop_handler.stop_triggered:\n        returned_tokens += stop_handler.finish()\n\n    return tokenizer.decode(returned_tokens)\n\n\n@require_test_tokenizers(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_stop_str_handler_stop(llama_tokenizer_path: str):  # pylint: disable=redefined-outer-name\n    stop_strs = [\" 🤔\"]\n    tokenizer = Tokenizer(llama_tokenizer_path)\n    stop_handler = StopStrHandler(stop_strs, tokenizer)\n\n    total_text = stop_handler_process_tokens(stop_handler, para_input_tokens, tokenizer)\n    expected_text = (\n        \"Sure, here's a short paragraph about emoji, \"\n        
\"where each word is followed by an emoji:\\n\\n\"\n        \"Hey there! 👋 I'm here to help answer any questions you have about emoji\"\n    )\n\n    assert total_text == expected_text\n\n\n@require_test_tokenizers(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_stop_str_handler_not_stop(\n    llama_tokenizer_path: str,  # pylint: disable=redefined-outer-name\n):\n    stop_strs = [\"^^\"]\n    tokenizer = Tokenizer(llama_tokenizer_path)\n    stop_handler = StopStrHandler(stop_strs, tokenizer)\n\n    total_text = stop_handler_process_tokens(stop_handler, para_input_tokens, tokenizer)\n    assert total_text == DECODED_PARAGRAPH\n\n\n@require_test_tokenizers(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_stop_str_handler_return_cached_tokens(\n    llama_tokenizer_path: str,  # pylint: disable=redefined-outer-name\n):\n    tokens = para_input_tokens[:26]  # until \"\\n\\n\"\n    stop_strs = [\"\\n\\n\\n\"]\n    tokenizer = Tokenizer(llama_tokenizer_path)\n    stop_handler = StopStrHandler(stop_strs, tokenizer)\n\n    total_text = stop_handler_process_tokens(stop_handler, tokens, tokenizer)\n    expected_text = (\n        \"Sure, here's a short paragraph about emoji, \"\n        \"where each word is followed by an emoji:\\n\\n\"\n    )\n\n    assert total_text == expected_text\n\n\n@require_test_tokenizers(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_stop_str_handler_throughput(\n    llama_tokenizer_path: str,  # pylint: disable=redefined-outer-name\n):\n    stop_strs = [\"[INST]\"]\n    tokenizer = Tokenizer(llama_tokenizer_path)\n    stop_handler = StopStrHandler(stop_strs, tokenizer)\n\n    tokens = para_input_tokens * 20\n    returned_tokens = []\n\n    tbegin = time.perf_counter()\n    for token in tokens:\n        returned_tokens += stop_handler.put(token)\n        assert not stop_handler.stop_triggered\n    tend = time.perf_counter()\n\n    throughput = len(tokens) / (tend - tbegin)\n    print(\n        f\"num tokens = {len(tokens)}, \"\n        f\"time elapsed = 
{tend - tbegin:.5f} sec, \"\n        f\"throughput = {throughput}\"\n    )\n    assert throughput >= 100000\n\n\nemoji_tokens_expected_result = [\n    # HF: \"�����\", SentencePiece: \"�👀\"\n    ([177, 243, 162, 148, 131], (\"�����\", \"�👀\")),\n    # Both: \"👀👀\"\n    ([243, 162, 148, 131, 243, 162, 148, 131], (\"👀👀\",)),\n    # Both: \"👀👀👀\"\n    ([243, 162, 148, 131, 243, 162, 148, 131, 243, 162, 148, 131], (\"👀👀👀\",)),\n    # HF: \"👀�������\", SentencePiece: \"👀���👀\"\n    ([243, 162, 148, 131, 162, 148, 131, 243, 162, 148, 131], (\"👀�������\", \"👀���👀\")),\n    # Both: \"👀��� have👀\"\n    ([243, 162, 148, 131, 162, 148, 131, 505, 243, 162, 148, 131], (\"👀��� have👀\",)),\n]\n\n\n@pytest.mark.parametrize(\"tokens_and_results\", emoji_tokens_expected_result)\n@require_test_tokenizers(\"Llama-2-7b-chat-hf-q4f16_1-MLC\")\ndef test_text_streamer_emojis(\n    llama_tokenizer_path: str, tokens_and_results: Tuple[List[int], Tuple[str]]\n):  # pylint: disable=redefined-outer-name\n    text_streamer = TextStreamer(Tokenizer(llama_tokenizer_path))\n    total_text = \"\"\n    tokens, expected_results = tokens_and_results\n    for token in tokens:\n        total_text += text_streamer.put([token])\n    total_text += text_streamer.finish()\n    assert total_text in expected_results\n\n\nif __name__ == \"__main__\":\n    test_text_streamer()\n    test_stop_str_handler_stop()\n    test_stop_str_handler_not_stop()\n    test_stop_str_handler_return_cached_tokens()\n    test_stop_str_handler_throughput()\n\n    for tokens_and_res in emoji_tokens_expected_result:\n        test_text_streamer_emojis(tokens_and_res)\n"
  },
  {
    "path": "version.py",
    "content": "# pylint: disable=missing-docstring\n\"\"\"\nThis is the global script that sets the version information of MLC-LLM.\nIt runs and updates all the locations that relate to versions.\n\nList of affected files:\n- mlc-llm-root/pyproject.toml\n\"\"\"\nimport argparse\nimport logging\nimport os\nimport re\nimport subprocess\n\n# Modify the following value during release\n# ---------------------------------------------------\n# Current version:\n# We use the version of the incoming release for code\n# that is under development.\n#\n# It is also the fallback version to be used when --git-describe\n# is not invoked, or when the repository does not present the\n# git tags in a format that this script can use.\n#\n# Two tag formats are supported:\n# - vMAJ.MIN.PATCH (e.g. v0.8.0) or\n# - vMAJ.MIN.devN (e.g. v0.8.dev0)\n\n# ---------------------------------------------------\n\n__version__ = \"0.1.dev0\"\nPROJ_ROOT = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))\n\n\ndef py_str(cstr):\n    return cstr.decode(\"utf-8\")\n\n\ndef git_describe_version():\n    \"\"\"Get PEP-440 compatible public and local version using git describe.\n\n    Returns\n    -------\n    pub_ver: str\n        Public version.\n\n    local_ver: str\n        Local version (with additional label appended to pub_ver).\n\n    Notes\n    -----\n    - We follow PEP 440's convention of public version\n      and local versions.\n    - Only tags conforming to vMAJOR.MINOR.REV (e.g. 
\"v0.7.0\")\n      are considered in order to generate the version string.\n      See the use of `--match` in the `git` command below.\n\n    Here are some examples:\n\n    - pub_ver = '0.7.0', local_ver = '0.7.0':\n      We are at the 0.7.0 release.\n    - pub_ver =  '0.8.dev94', local_ver = '0.8.dev94+g0d07a329e':\n      We are at the 0.8 development cycle.\n      The current source contains 94 additional commits\n      after the most recent tag(v0.7.0),\n      the git short hash tag of the current commit is 0d07a329e.\n    \"\"\"\n    cmd = [\n        \"git\",\n        \"describe\",\n        \"--tags\",\n        \"--match\",\n        \"v[0-9]*.[0-9]*.[0-9]*\",\n        \"--match\",\n        \"v[0-9]*.[0-9]*.dev[0-9]*\",\n    ]\n    with subprocess.Popen(\n        cmd,\n        stdout=subprocess.PIPE,\n        stderr=subprocess.STDOUT,\n        cwd=PROJ_ROOT,\n    ) as proc:\n        (out, _) = proc.communicate()\n\n    if proc.returncode != 0:\n        msg = py_str(out)\n        logging.warning(\"git describe: %s\", msg)\n        return None, None\n    describe = py_str(out).strip()\n    arr_info = describe.split(\"-\")\n\n    # Remove the v prefix, mainly to be robust\n    # to the case where v is not presented as well.\n    if arr_info[0].startswith(\"v\"):\n        arr_info[0] = arr_info[0][1:]\n\n    # hit the exact tag\n    if len(arr_info) == 1:\n        return arr_info[0], arr_info[0]\n\n    if len(arr_info) != 3:\n        logging.warning(\"Invalid output from git describe %s\", describe)\n        return None, None\n\n    dev_pos = arr_info[0].find(\".dev\")\n\n    # Development versions:\n    # The code will reach this point in case it can't match a full release version, such as v0.7.0.\n    #\n    # 1. in case the last known label looks like vMAJ.MIN.devN e.g. v0.8.dev0, we use\n    # the current behavior of just using vMAJ.MIN.devNNNN+gGIT_REV\n    if dev_pos != -1:\n        dev_version = arr_info[0][: arr_info[0].find(\".dev\")]\n    # 2. 
in case the last known label looks like vMAJ.MIN.PATCH e.g. v0.8.0\n    # then we just carry on with a similar version to what git describe provides, which is\n    # vMAJ.MIN.PATCH.devNNNN+gGIT_REV\n    else:\n        dev_version = arr_info[0]\n\n    pub_ver = f\"{dev_version}.dev{arr_info[1]}\"\n    local_ver = f\"{pub_ver}+{arr_info[2]}\"\n    return pub_ver, local_ver\n\n\n# Implementations\ndef update(file_name, pattern, repl, dry_run=False):\n    update = []\n    hit_counter = 0\n    need_update = False\n    with open(file_name) as file:\n        for l in file:\n            result = re.findall(pattern, l)\n            if result:\n                assert len(result) == 1\n                hit_counter += 1\n                if result[0] != repl:\n                    l = re.sub(pattern, repl, l)\n                    need_update = True\n                    print(\"%s: %s -> %s\" % (file_name, result[0], repl))\n                else:\n                    print(\"%s: version is already %s\" % (file_name, repl))\n\n            update.append(l)\n    if hit_counter != 1:\n        raise RuntimeError(\"Cannot find version in %s\" % file_name)\n\n    if need_update and not dry_run:\n        with open(file_name, \"w\") as output_file:\n            for l in update:\n                output_file.write(l)\n\n\ndef sync_version(pub_ver, local_ver, dry_run):\n    \"\"\"Synchronize version.\"\"\"\n    # pyproject.toml\n    update(\n        os.path.join(PROJ_ROOT, \"pyproject.toml\"),\n        r\"(?<=version = \\\")[.0-9a-z\\+]+\",\n        pub_ver,\n        dry_run,\n    )\n\n\ndef main():\n    logging.basicConfig(level=logging.INFO)\n    parser = argparse.ArgumentParser(description=\"Detect and synchronize version.\")\n    parser.add_argument(\n        \"--print-version\",\n        action=\"store_true\",\n        help=\"Print version to the command line. 
 No changes are applied to files.\",\n    )\n    parser.add_argument(\n        \"--git-describe\",\n        action=\"store_true\",\n        help=\"Use git describe to generate the development version.\",\n    )\n    parser.add_argument(\"--dry-run\", action=\"store_true\")\n    opt = parser.parse_args()\n    pub_ver, local_ver = None, None\n    if opt.git_describe:\n        pub_ver, local_ver = git_describe_version()\n    if pub_ver is None:\n        pub_ver = __version__\n    if local_ver is None:\n        local_ver = __version__\n    if opt.print_version:\n        print(local_ver)\n    else:\n        sync_version(pub_ver, local_ver, opt.dry_run)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "web/Makefile",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\nTVM_ROOT=$(TVM_SOURCE_DIR)\nMLC_LLM_ROOT=$(shell cd ..; pwd)\n\nINCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\\\n\t-I$(TVM_ROOT)/3rdparty/dlpack/include\\\n\t-I$(TVM_ROOT)/3rdparty/compiler-rt\\\n\t-I$(MLC_LLM_ROOT)/3rdparty/tokenizers-cpp\\\n\t-I$(MLC_LLM_ROOT)/3rdparty/tokenizers-cpp/include -I$(MLC_LLM_ROOT)/cpp\n\n.PHONY: clean all rmtypedep preparetest\n\nall: dist/wasm/mlc_wasm_runtime.wasm\n\nEMCC = emcc\n\nEMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes\n\nEMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\\\n -s ERROR_ON_UNDEFINED_SYMBOLS=0\n\ndist/wasm/mlc_wasm_runtime.bc: emcc/mlc_wasm_runtime.cc\n\t@mkdir -p $(@D)\n\t$(EMCC) $(EMCC_CFLAGS) -c -MM -MT dist/wasm/mlc_wasm_runtime.bc emcc/mlc_wasm_runtime.cc >dist/wasm/mlc_wasm_runtime.d\n\t$(EMCC) $(EMCC_CFLAGS) -emit-llvm -c -o dist/wasm/mlc_wasm_runtime.bc emcc/mlc_wasm_runtime.cc\n\n# Compile to wasm here so that errors can be caught earlier (rather than during export_library)\ndist/wasm/mlc_wasm_runtime.wasm: dist/wasm/mlc_wasm_runtime.bc\n\t@mkdir -p $(@D)\n\t$(EMCC) $(EMCC_CFLAGS) -o dist/wasm/mlc_wasm_runtime.wasm $+ 
$(EMCC_LDFLAGS)\n\nclean:\n\t@rm -rf dist/wasm lib\n\n-include dist/wasm/*.d\n"
  },
  {
    "path": "web/README.md",
    "content": "<!--- Licensed to the Apache Software Foundation (ASF) under one -->\n<!--- or more contributor license agreements.  See the NOTICE file -->\n<!--- distributed with this work for additional information -->\n<!--- regarding copyright ownership.  The ASF licenses this file -->\n<!--- to you under the Apache License, Version 2.0 (the -->\n<!--- \"License\"); you may not use this file except in compliance -->\n<!--- with the License.  You may obtain a copy of the License at -->\n\n<!---   http://www.apache.org/licenses/LICENSE-2.0 -->\n\n<!--- Unless required by applicable law or agreed to in writing, -->\n<!--- software distributed under the License is distributed on an -->\n<!--- \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->\n<!--- KIND, either express or implied.  See the License for the -->\n<!--- specific language governing permissions and limitations -->\n<!--- under the License. -->\n\n# MLC-LLM WebAssembly Runtime\n\nThis folder contains the MLC-LLM WebAssembly runtime.\n\nPlease refer to https://llm.mlc.ai/docs/install/emcc.html.\n\nThe main step is running `make` in this folder, which `web/prep_emcc_deps.sh` also does.\n\n`make` creates `web/dist/wasm/mlc_wasm_runtime.bc`, which is included in the model library wasm\nwhen we compile the model. Thus, at runtime, runtimes such as WebLLM can directly reuse source\ncode from MLC-LLM.\n"
  },
  {
    "path": "web/emcc/mlc_wasm_runtime.cc",
    "content": "/*\n * Licensed to the Apache Software Foundation (ASF) under one\n * or more contributor license agreements.  See the NOTICE file\n * distributed with this work for additional information\n * regarding copyright ownership.  The ASF licenses this file\n * to you under the Apache License, Version 2.0 (the\n * \"License\"); you may not use this file except in compliance\n * with the License.  You may obtain a copy of the License at\n *\n *   http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing,\n * software distributed under the License is distributed on an\n * \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n * KIND, either express or implied.  See the License for the\n * specific language governing permissions and limitations\n * under the License.\n */\n\n/*\n * \\file mlc_wasm_runtime.cc\n * \\brief MLC wasm runtime library pack.\n */\n\n// configurations for tvm logging\n#define TVM_LOG_STACK_TRACE 0\n#define TVM_LOG_DEBUG 0\n#define TVM_LOG_CUSTOMIZE 1\n\n// Define COMPILE_MLC_WASM_RUNTIME so that unsupported code is not compiled into the .bc file\n#define COMPILE_MLC_WASM_RUNTIME 1\n#define __STDC_FORMAT_MACROS 1\n#define PICOJSON_USE_INT64\n"
  },
  {
    "path": "web/prep_emcc_deps.sh",
    "content": "#!/bin/bash\n# This file prepares all the necessary dependencies for the web build.\nset -euxo pipefail\n\nemcc --version\nnpm --version\n\nTVM_SOURCE_DIR_SET=\"${TVM_SOURCE_DIR:-}\"\n\ngit submodule update --init --recursive\n\nCURR_DIR=$(pwd)\n\nif [[ -z \"${TVM_SOURCE_DIR_SET}\" ]]; then\n    echo \"TVM_SOURCE_DIR env variable not found, using 3rdparty/tvm.\"\n    echo \"Make sure you set TVM_SOURCE_DIR in your env variables to use the emcc build correctly.\"\n    export TVM_SOURCE_DIR=\"${TVM_SOURCE_DIR:-${CURR_DIR}/3rdparty/tvm}\"\nfi\n\n# Build mlc_wasm_runtime\ncd web && make\ncd -\n\n# Build tvm's web runtime\ncd \"${TVM_SOURCE_DIR}/web\" && TVM_HOME=\"${TVM_SOURCE_DIR}\" make\ncd -\n"
  }
]